diff --git a/UNet_main_test.py b/UNet_main_test.py
new file mode 100755
index 0000000..befa7a4
--- /dev/null
+++ b/UNet_main_test.py
@@ -0,0 +1,172 @@
+# UNet_main_test.py : Main Code for testing the UNet accompanying publication "Classification of prostate cancer on MRI: Deep learning vs. clinical PI-RADS assessment", Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel Wiesenfarth PhD, Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan Anselm Kuder PhD, Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter Schlemmer MD, PhD, Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology, [manuscript accepted for publication]
+# Copyright (C) 2019 German Cancer Research Center (DKFZ)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de

__author__ = "German Cancer Research Center (DKFZ)"


import os
import numpy as np
import pandas as pd
import torch
import argparse
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.spatial_transforms import SpatialTransform, Mirror
from batchgenerators.transforms.noise_transforms import RicianNoiseTransform
from batchgenerators.transforms.resample_transforms import ResampleTransform
from batchgenerators.transforms.abstract_transforms import RndTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.transforms.sample_normalization_transforms import CutOffOutliersTransform
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, BrightnessTransform, ContrastAugmentationTransform, GammaTransform
from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform
from UNet_net import UNetPytorch
from UNet_utils import BatchGenerator, CreateTrainValTestSplit, validate, get_class_frequencies, CrossEntropyLoss2d


####################################################################################################################
### Settings
####################################################################################################################

def _parse_patch_size(value):
    """Parse a --patch_size value such as '160,160' into a tuple of ints.

    The previous ``type=tuple`` would split a command-line string into single
    characters; this parser accepts a comma-separated list and leaves the
    tuple default untouched.
    """
    if isinstance(value, tuple):
        return value
    return tuple(int(v) for v in value.split(','))


parser = argparse.ArgumentParser(description='RADIOLOGY-2019 Prostate Lesion Segmentation ensemble testing script')

parser.add_argument('--input_file', default='', type=str, help='name of input file')
parser.add_argument('--ensemble_path', default='', type=str, help='path to pretrained CNN ensemble')
parser.add_argument('--checkpoint_name', default='checkpoint_UNetPytorch.pth.tar', type=str, help='name of ensemble checkpoint files')
parser.add_argument('--output_path', default='', type=str, help='name of output folder')
parser.add_argument('--num_splits', default=10, type=int, help='split dataset in 10 folds')
parser.add_argument('--num_val_folds', default=2, type=int, help='leave x fold out for validation')
parser.add_argument('--num_test_folds', default=2, type=int, help='leave x fold out for testing')
parser.add_argument('--seed', default=42, type=int, help='seed')
parser.add_argument('--patch_size', default=(160, 160), type=_parse_patch_size, help='define patch size, e.g. 160,160')


def _class_weights(background, prostate, tumor):
    """Return normalized (background, prostate, tumor) loss weights.

    Uses inverse fourth-root frequency weighting, w_c = 1 / (freq_c / total) ** 0.25,
    normalized so the three weights sum to 1 (same arithmetic as the published
    per-modality weight computation).
    """
    total = float(background + prostate + tumor)
    raw = [1.0 / (freq / total) ** 0.25 for freq in (background, prostate, tumor)]
    norm = sum(raw)
    return tuple(w / norm for w in raw)


def main():
    """Evaluate every checkpoint of the pretrained ensemble on the held-out test split."""
    # assign global args
    global args
    args = parser.parse_args()

    # make a folder for the experiment (ignore 'already exists')
    general_folder_name = args.output_path
    try:
        os.mkdir(general_folder_name)
    except OSError:
        pass

    # create train/val/test split; patients in the test split were never seen during training
    train_idx, val_idx, test_idx = CreateTrainValTestSplit(HistoFile_path=args.input_file,
                                                          num_splits=args.num_splits,
                                                          num_test_folds=args.num_test_folds,
                                                          num_val_folds=args.num_val_folds,
                                                          seed=args.seed)

    print('size of training set {}'.format(len(train_idx)))
    print('size of validation set {}'.format(len(val_idx)))
    print('size of test set {}'.format(len(test_idx)))

    # frequencies / normalization constants are computed over train + val
    train_idx = train_idx + val_idx

    # data loading
    # NOTE(review): ProstataData is not imported in this module; it must be
    # provided by the project (see README) -- confirm the import before running.
    Data = ProstataData(args.input_file)

    # get class frequencies
    print('calculating class frequencie')

    Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC, \
        Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2, \
        ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std \
        = get_class_frequencies(Data, train_idx, patch_size=args.patch_size)

    print(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std)
    print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC)
    print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2)

    # np.float was removed from NumPy; the builtin float is equivalent here
    all_ADC = float(Background_frequencie_ADC + Prostate_frequencie_ADC + Tumor_frequencie_ADC)
    all_T2 = float(Background_frequencie_T2 + Prostate_frequencie_T2 + Tumor_frequencie_T2)
    print(all_ADC)
    print(all_T2)

    weight_ADC = _class_weights(Background_frequencie_ADC, Prostate_frequencie_ADC, Tumor_frequencie_ADC)
    weight_T2 = _class_weights(Background_frequencie_T2, Prostate_frequencie_T2, Tumor_frequencie_T2)
    print('Weights ADC', *weight_ADC)
    print('Weights T2', *weight_T2)

    criterion_ADC = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_ADC)).cuda()
    criterion_T2 = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_T2)).cuda()

    Center_Crop = CenterCropTransform(args.patch_size)
    count = 0

    # evaluate every checkpoint file found below the ensemble path
    for folder, subfolders, files in sorted(os.walk(args.ensemble_path)):
        for file in files:
            if file.startswith(args.checkpoint_name):
                epoch = 0

                # define model
                Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
                model = Net.cuda()

                folder_name_for_single_nets = general_folder_name + '/Net_{}'.format(count)
                try:
                    os.mkdir(folder_name_for_single_nets)
                except OSError:
                    pass

                model_path = os.path.join(os.path.abspath(folder), file)
                checkpoint = torch.load(model_path)
                model.load_state_dict(checkpoint['state_dict'])
                test_loader = BatchGenerator(Data, BATCH_SIZE=0, split_idx=test_idx, seed=args.seed,
                                             ProbabilityTumorSlices=None, epoch=epoch, test=True,
                                             ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean,
                                             BVAL_std=BVAL_std, T2_mean=T2_mean, T2_std=T2_std)

                test_loss = validate(test_loader, model,
                                     folder_name=folder_name_for_single_nets,
                                     criterion_ADC=criterion_ADC, criterion_T2=criterion_T2,
                                     split_ixs=test_idx, epoch=epoch, workers=1,
                                     Center_Crop=Center_Crop, seed=args.seed, test=True)

                TestCSV = pd.DataFrame({'test_loss': [test_loss]})
                TestCSV.to_csv(folder_name_for_single_nets + '/TestCSV.csv')
                count += 1


if __name__ == '__main__':
    main()


# ==============================================================================
# diff --git a/UNet_main_train.py b/UNet_main_train.py
# new file mode 100755
# index 0000000..7682b3e
# --- /dev/null
# +++ b/UNet_main_train.py
# @@ -0,0 +1,304 @@
# UNet_main_train.py : Main Code for training the UNet accompanying publication
# "Classification of prostate cancer on MRI: Deep learning vs. clinical PI-RADS
# assessment", Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel
# Wiesenfarth PhD, Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan
# Anselm Kuder PhD, Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter
# Schlemmer MD, PhD, Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology,
# [manuscript accepted for publication]
# Copyright (C) 2019 German Cancer Research Center (DKFZ)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de

__author__ = "German Cancer Research Center (DKFZ)"


import numpy as np
import pandas as pd
import os
import torch
import argparse
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.spatial_transforms import SpatialTransform, Mirror
from batchgenerators.transforms.noise_transforms import RicianNoiseTransform
from batchgenerators.transforms.resample_transforms import ResampleTransform
from batchgenerators.transforms.abstract_transforms import RndTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.transforms.sample_normalization_transforms import CutOffOutliersTransform
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, BrightnessTransform, ContrastAugmentationTransform, GammaTransform
from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform
from UNet_net import UNetPytorch
from UNet_utils import train, save_checkpoint, adjust_learning_rate, \
    BatchGenerator, CreateTrainValTestSplit, CrossEntropyLoss2d, get_class_frequencies, clear_image_data, \
    split_training, validate, \
    get_oversampling


####################################################################################################################
### Settings
####################################################################################################################

def _parse_patch_size(value):
    """Parse a --patch_size value such as '160,160' into a tuple of ints.

    The previous ``type=tuple`` would split a command-line string into single
    characters; this parser accepts a comma-separated list.
    """
    if isinstance(value, tuple):
        return value
    return tuple(int(v) for v in value.split(','))


parser = argparse.ArgumentParser(description='RADIOLOGY-2019 Prostate Lesion Segmentation training script')

parser.add_argument('--input_file', default='', type=str, help='name of input file')
parser.add_argument('--output_path', default='', type=str, help='name of output folder')
parser.add_argument('--arch', default='UNet')
parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loading workers')
parser.add_argument('--epochs', default=80, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--b', default=32, type=int, help='mini-batch size')
parser.add_argument('--lr', default=0.00015, type=float, help='initial learning rate')
parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
parser.add_argument('--num_splits', default=10, type=int, help='split dataset in 10 folds')
parser.add_argument('--num_val_folds', default=2, type=int, help='leave x fold out for validation')
parser.add_argument('--num_test_folds', default=2, type=int, help='leave x fold out for testing')
parser.add_argument('--seed', default=42, type=int, help='seed')
parser.add_argument('--patch_size', default=(160, 160), type=_parse_patch_size, help='define patch size, e.g. 160,160')
parser.add_argument('--p', default=1, type=float, help='probability for spatial transform')
parser.add_argument('--cv_number', default=16, type=int, help='number of cross validation loops')
parser.add_argument('--cv_start', default=0, type=int, help='')
parser.add_argument('--cv_end', default=16, type=int, help='')


def _class_weights(background, prostate, tumor):
    """Return normalized (background, prostate, tumor) loss weights.

    Inverse fourth-root frequency weighting: w_c = 1 / (freq_c / total) ** 0.25,
    normalized so the three weights sum to 1.
    """
    total = float(background + prostate + tumor)
    raw = [1.0 / (freq / total) ** 0.25 for freq in (background, prostate, tumor)]
    norm = sum(raw)
    return tuple(w / norm for w in raw)


def _build_augmentation(patch_size, prob, alpha, sigma, scale, zoom_range, contrast_range,
                        random_crop, use_mirror):
    """Build the epoch-dependent augmentation pipeline.

    Returns (final_transform, center_crop). The spatial transform is applied
    with probability ``prob`` (falling back to a center crop), the intensity
    transforms (resample, contrast, brightness, optional mirror) always run.
    """
    spatial = SpatialTransform(patch_size, np.array(patch_size) // 2,
                               do_elastic_deform=True, alpha=alpha, sigma=sigma,
                               do_rotation=True, angle_z=(-np.pi / 2., np.pi / 2.),
                               do_scale=True, scale=scale,
                               border_mode_data='constant', border_cval_data=0,
                               order_data=3,
                               random_crop=random_crop)
    others = [ResampleTransform(zoom_range=zoom_range),
              ContrastAugmentationTransform(contrast_range, True),
              BrightnessTransform(0.0, 0.1, True)]
    if use_mirror:
        others.append(Mirror((2, 3)))
    sometimes_spatial = RndTransform(spatial, prob=prob,
                                     alternative_transform=CenterCropTransform(patch_size))
    sometimes_other = RndTransform(Compose(others), prob=1.0)
    return Compose([sometimes_spatial, sometimes_other]), CenterCropTransform(patch_size)


def main():
    """Cross-validated training of the UNet ensemble."""
    # assign global args
    global args
    args = parser.parse_args()

    # make a folder for the experiment (ignore 'already exists')
    general_folder_name = args.output_path
    try:
        os.mkdir(general_folder_name)
    except OSError:
        pass

    # create train/val/test split; patients in the test split are never seen during training
    train_idx, val_idx, test_idx = CreateTrainValTestSplit(HistoFile_path=args.input_file,
                                                          num_splits=args.num_splits,
                                                          num_test_folds=args.num_test_folds,
                                                          num_val_folds=args.num_val_folds,
                                                          seed=args.seed)

    IDs = train_idx + val_idx

    print('size of training set {}'.format(len(train_idx)))
    print('size of validation set {}'.format(len(val_idx)))
    print('size of test set {}'.format(len(test_idx)))

    # data loading
    # NOTE(review): ProstataData is not imported in this module; it must be
    # provided by the project (see README) -- confirm the import before running.
    Data = ProstataData(args.input_file)

    # train and validate one model per cross-validation run
    for cv in range(args.cv_start, args.cv_end):
        best_epoch = 0
        train_loss = []
        val_loss = []

        # define patients for training and validation
        train_idx, val_idx = split_training(IDs, len_val=62, cv=cv, cv_runs=args.cv_number)

        oversampling_factor, Slices_total, Natural_probability_tu_slice, Natural_probability_PRO_slice = \
            get_oversampling(Data, train_idx=sorted(train_idx), Batch_Size=args.b, patch_size=args.patch_size)

        # integer division: range() below needs an int (the Python-2 original used '/')
        training_batches = Slices_total // args.b

        lr = args.lr
        base_lr = args.lr
        args.seed += 1

        print('train_idx', train_idx, len(train_idx))
        print('val_idx', val_idx, len(val_idx))

        # get class frequencies
        print('calculating class frequencie')

        Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC, \
            Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2, \
            ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std \
            = get_class_frequencies(Data, train_idx, patch_size=args.patch_size)

        print(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std)
        print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC)
        print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2)

        # np.float was removed from NumPy; the builtin float is equivalent here
        all_ADC = float(Background_frequencie_ADC + Prostate_frequencie_ADC + Tumor_frequencie_ADC)
        all_T2 = float(Background_frequencie_T2 + Prostate_frequencie_T2 + Tumor_frequencie_T2)
        print(all_ADC)
        print(all_T2)

        weight_ADC = _class_weights(Background_frequencie_ADC, Prostate_frequencie_ADC, Tumor_frequencie_ADC)
        weight_T2 = _class_weights(Background_frequencie_T2, Prostate_frequencie_T2, Tumor_frequencie_T2)
        print('Weights ADC', *weight_ADC)
        print('Weights T2', *weight_T2)

        # define model
        Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
        Net_Name = 'UNetPytorch'
        model = Net.cuda()

        # model parameter
        optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=args.weight_decay)
        criterion_ADC = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_ADC)).cuda()
        criterion_T2 = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_T2)).cuda()

        # new folder name for cv
        folder_name = general_folder_name + '/CV_{}'.format(cv)
        try:
            os.mkdir(folder_name)
        except OSError:
            pass

        checkpoint_file = folder_name + '/checkpoint_' + '{}.pth.tar'.format(Net_Name)

        for epoch in range(args.epochs):
            torch.manual_seed(args.seed + epoch + cv)
            np.random.seed(epoch + cv)
            np.random.shuffle(train_idx)

            # augmentation strength is reduced twice during training (epochs 0 / 30 / 50)
            if epoch == 0:
                final_transform, Center_Crop = _build_augmentation(
                    args.patch_size, args.p, alpha=(100., 450.), sigma=(13., 17.),
                    scale=(0.75, 1.25), zoom_range=(0.7, 1.3), contrast_range=(0.75, 1.25),
                    random_crop=True, use_mirror=True)
            if epoch == 30:
                final_transform, Center_Crop = _build_augmentation(
                    args.patch_size, args.p, alpha=(0., 250.), sigma=(11., 14.),
                    scale=(0.85, 1.15), zoom_range=(0.8, 1.2), contrast_range=(0.85, 1.15),
                    random_crop=True, use_mirror=False)
            if epoch == 50:
                final_transform, Center_Crop = _build_augmentation(
                    args.patch_size, args.p, alpha=(0., 150.), sigma=(10., 12.),
                    scale=(0.85, 1.15), zoom_range=(0.9, 1.1), contrast_range=(0.95, 1.05),
                    random_crop=False, use_mirror=False)

            train_loader = BatchGenerator(Data, BATCH_SIZE=args.b, split_idx=train_idx, seed=args.seed,
                                          ProbabilityTumorSlices=oversampling_factor, epoch=epoch,
                                          ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean,
                                          BVAL_std=BVAL_std, T2_mean=T2_mean, T2_std=T2_std)

            val_loader = BatchGenerator(Data, BATCH_SIZE=0, split_idx=val_idx, seed=args.seed,
                                        ProbabilityTumorSlices=None, epoch=epoch, test=True,
                                        ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean,
                                        BVAL_std=BVAL_std, T2_mean=T2_mean, T2_std=T2_std)

            # train on training set
            train_losses = train(train_loader=train_loader, model=model, optimizer=optimizer,
                                 criterion_ADC=criterion_ADC, criterion_T2=criterion_T2,
                                 final_transform=final_transform, workers=args.workers, seed=args.seed,
                                 training_batches=training_batches)
            train_loss.append(train_losses)

            # evaluate on validation set
            val_losses = validate(val_loader=val_loader, model=model, folder_name=folder_name,
                                  criterion_ADC=criterion_ADC, criterion_T2=criterion_T2,
                                  split_ixs=val_idx, epoch=epoch, workers=1,
                                  Center_Crop=Center_Crop, seed=args.seed)
            val_loss.append(val_losses)

            # write TrainingsCSV to folder name
            TrainingsCSV = pd.DataFrame({'train_loss': train_loss, 'val_loss': val_loss})
            TrainingsCSV.to_csv(folder_name + '/TrainingsCSV.csv')

            # keep the checkpoint of the best validation epoch seen so far
            if val_losses <= min(val_loss):
                best_epoch = epoch
                print('best epoch', epoch)
                save_checkpoint({'epoch': epoch, 'arch': args.arch, 'state_dict': model.state_dict(),
                                 'optimizer': optimizer.state_dict()
                                 }, filename=checkpoint_file)

            optimizer, lr = adjust_learning_rate(optimizer, base_lr, epoch)

        # delete all output except best epoch
        clear_image_data(folder_name, best_epoch, epoch)


if __name__ == '__main__':
    main()


# ==============================================================================
# diff --git a/UNet_net.py b/UNet_net.py
# new file mode 100755
# index 0000000..beb4379
# --- /dev/null
# +++ b/UNet_net.py
# @@ -0,0 +1,126 @@
# UNet_net.py : CNN architecture accompanying publication "Classification of
# prostate cancer on MRI: Deep learning vs. clinical PI-RADS assessment",
# Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel Wiesenfarth PhD,
# Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan Anselm Kuder PhD,
# Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter Schlemmer MD, PhD,
# Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology,
# [manuscript accepted for publication]
# Copyright (C) 2019 German Cancer Research Center (DKFZ)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de

__author__ = "German Cancer Research Center (DKFZ)"


import torch
import torch.nn as nn
import torch.nn.functional as F

# modified Unet implementation from https://www.kaggle.com/c/carvana-image-masking-challenge/discussion/37208

# epsilon used by every BatchNorm layer in this network
BN_EPS = 1e-5


class UNetPytorch(nn.Module):
    """2D UNet with five encoder/decoder stages and a 3-class output head.

    in_shape: (channels, height, width) of the input, e.g. (3, 160, 160) for
    the stacked ADC / BVAL / T2 images used in the accompanying scripts.
    The final 1x1 convolution maps to 3 channels; raw logits are returned
    (callers apply softmax / the loss themselves).
    """

    def __init__(self, in_shape):
        super(UNetPytorch, self).__init__()
        C, H, W = in_shape

        # encoder: each stage doubles the channel count and halves the resolution
        self.down2 = StackEncoder(C, 64, kernel_size=3)
        self.down3 = StackEncoder(64, 128, kernel_size=3)
        self.down4 = StackEncoder(128, 256, kernel_size=3)
        self.down5 = StackEncoder(256, 512, kernel_size=3)
        self.down6 = StackEncoder(512, 1024, kernel_size=3)

        self.center = nn.Sequential(
            ConvRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
            ConvRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
            ConvRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1))

        # decoder: upsample and fuse with the matching encoder activations
        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)
        self.up5 = StackDecoder(512, 512, 256, kernel_size=3)
        self.up4 = StackDecoder(256, 256, 128, kernel_size=3)
        self.up3 = StackDecoder(128, 128, 64, kernel_size=3)
        self.up2 = StackDecoder(64, 64, 32, kernel_size=3)

        self.classify = nn.Conv2d(32, 3, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, input):
        out = input

        # keep the pre-pooling activations for the skip connections
        down2, out = self.down2(out)
        down3, out = self.down3(out)
        down4, out = self.down4(out)
        down5, out = self.down5(out)
        down6, out = self.down6(out)

        out = self.center(out)

        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)
        out = self.up2(down2, out)

        return self.classify(out)


class ConvRelu2d(nn.Module):
    """3x3 convolution optionally followed by ReLU and BatchNorm.

    NOTE(review): the activation is applied *before* batch norm
    (conv -> relu -> bn), which is unusual, but kept as-is to stay
    compatible with the published checkpoints.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1,
                 stride=1, groups=1, is_relu=True, is_bn=True):
        super(ConvRelu2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
                              stride=stride, dilation=dilation, groups=groups, bias=False)
        # relu / bn are optional; a None sub-module is simply skipped in forward()
        self.relu = nn.ReLU(inplace=True) if is_relu else None
        self.bn = nn.BatchNorm2d(out_channels, eps=BN_EPS) if is_bn else None

    def forward(self, input):
        convoluted = self.conv(input)
        if self.relu is not None:
            convoluted = self.relu(convoluted)
        if self.bn is not None:
            convoluted = self.bn(convoluted)
        return convoluted


class StackEncoder(nn.Module):
    """Three ConvRelu2d layers followed by 2x2 max pooling.

    forward() returns (encoded, max_pooled): the full-resolution activations
    (used for the skip connection) and the downsampled tensor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(StackEncoder, self).__init__()
        padding = (kernel_size - 1) // 2

        self.encode = nn.Sequential(
            ConvRelu2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvRelu2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvRelu2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1))

    def forward(self, input):
        encoded = self.encode(input)
        max_pooled = F.max_pool2d(encoded, kernel_size=2, stride=2)
        return encoded, max_pooled


class StackDecoder(nn.Module):
    """Bilinear upsampling + skip concatenation + three ConvRelu2d layers.

    forward(down_input, input) upsamples `input` to the spatial size of
    `down_input`, concatenates the two along the channel axis, and decodes.
    """

    def __init__(self, in_channels_down, in_channels, out_channels, kernel_size=3):
        super(StackDecoder, self).__init__()
        padding = (kernel_size - 1) // 2

        self.decode = nn.Sequential(
            ConvRelu2d(in_channels_down + in_channels, out_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvRelu2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvRelu2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1))

    def forward(self, down_input, input):
        N, C, H, W = down_input.size()
        upsampled = F.interpolate(input, size=(H, W), mode='bilinear', align_corners=True)
        upsampled = torch.cat([upsampled, down_input], 1)
        return self.decode(upsampled)


# ==============================================================================
# diff --git a/UNet_utils.py b/UNet_utils.py
# new file mode 100755
# index 0000000..12ace95
# --- /dev/null
# +++ b/UNet_utils.py
# @@ -0,0 +1,650 @@
# UNet_utils.py : Utilities Code (function definitions) accompanying publication
# "Classification of prostate cancer on MRI: Deep learning vs. clinical PI-RADS
# assessment", Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel
# Wiesenfarth PhD, Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan
# Anselm Kuder PhD, Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter
# Schlemmer MD, PhD, Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology,
# [manuscript accepted for publication]
# Copyright (C) 2019 German Cancer Research Center (DKFZ)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de + +__author__ = "German Cancer Research Center (DKFZ)" + +import pandas as pd +import os +import SimpleITK as sitk +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading.data_loader import DataLoaderBase +from builtins import object +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import torch +import shutil +import nrrd +from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop +from batchgenerators.transforms.sample_normalization_transforms import CutOffOutliersTransform +from batchgenerators.augmentations.normalizations import cut_off_outliers + + +class CrossEntropyLoss2d(nn.Module): + def __init__(self, weight=None, size_average=True, reduce=True): + super(CrossEntropyLoss2d, self).__init__() + self.nll_loss = nn.NLLLoss2d(weight=weight, size_average=size_average, reduce=reduce) + + def forward(self, inputs, targets): + return self.nll_loss(F.log_softmax(inputs), targets) + + +def ToTensor(batch): + + image, label = batch['data'], batch['seg'] + + data = torch.from_numpy(image[:,:,:,:]) + seg = torch.from_numpy(label[:,0,:,:]) + seg_T2 = torch.from_numpy(label[:,1,:,:]) + + return {'data': data, + 'seg': seg, + 'seg_T2': seg_T2} + + + +def train(train_loader, model, optimizer, criterion_ADC, criterion_T2, final_transform, workers, seed, + training_batches): + + train_losses = AverageMeter() + np.random.seed(seed) + seeds = np.random.choice(seed, workers, False, None) + model.train() + multithreaded_generator = MultiThreadedAugmenter(train_loader, final_transform, workers, 2, seeds=seeds) + torch.cuda.empty_cache() + + for i in range(training_batches): + print('Batch: [{0}/{1}]'.format(i +1, training_batches)) + batch = multithreaded_generator.next() + TensorBatch = ToTensor(batch) + target = TensorBatch['seg'].cuda() + target_T2 = TensorBatch['seg_T2'].cuda() + input_var = 
torch.autograd.Variable(TensorBatch['data'], requires_grad=True).cuda(async=True) + input_var = input_var.float() + target_var = torch.autograd.Variable(target) + target_var = target_var.long() + target_var_T2 = torch.autograd.Variable(target_T2) + target_var_T2 = target_var_T2.long() + optimizer.zero_grad() + output = model(input_var) + loss_ADC = criterion_ADC(output, target_var) + loss_T2 = criterion_T2(output, target_var_T2) + loss = (loss_ADC + loss_T2) / 2. + loss.backward() + optimizer.step() + train_losses.update(loss.item()) + print 'train_loss', loss.item() + + torch.cuda.empty_cache() + + return train_losses.avg + + + +def validate(val_loader, model, epoch, criterion_ADC, criterion_T2 ,split_ixs, Center_Crop, workers, seed, folder_name, + test=False): + + val_losses = AverageMeter() + seeds = np.random.choice(seed, workers, False, None) + torch.cuda.empty_cache() + model.eval() + multithreaded_generator = MultiThreadedAugmenter(val_loader, Center_Crop, workers, 2, seeds=seeds) + + for i in range(len(split_ixs)): + patient = split_ixs[i] + print 'patient', patient + batch = multithreaded_generator.next() + TensorBatch = ToTensor(batch) + target = TensorBatch['seg'].cuda() + target_T2 = TensorBatch['seg_T2'].cuda() + input_var = torch.autograd.Variable(TensorBatch['data'], volatile=True).cuda(async=True) + input_var = input_var.float() + target_var = torch.autograd.Variable(target, volatile=True) + target_var = target_var.long() + target_var_T2 = torch.autograd.Variable(target_T2, volatile=True) + target_var_T2 = target_var_T2.long() + output = model(input_var) + probs = F.softmax(output) + loss_ADC = criterion_ADC(output, target_var) + loss_T2 = criterion_T2(output, target_var_T2) + loss = (loss_ADC + loss_T2) / 2. 
+ val_losses.update(loss.item()) + if test == False: + print 'val_loss', loss.item() + else: + print 'test_loss', loss.item() + + image = (input_var.data).cpu().numpy() + Mprobs = (probs.data).cpu().numpy() + fprobs = (probs.data).cpu().numpy() + segmentation = (target).cpu().numpy() + segmentation_T2 = (target_T2).cpu().numpy() + label = np.where(segmentation == 2, 1, 0) + label_T2 = np.where(segmentation_T2 == 2, 1, 0) + label = np.uint8(label) + label_T2 = np.uint8(label_T2) + PRO = np.where(segmentation == 1, 1, 0) + PRO = np.uint8(PRO) + PRO_T2 = np.where(segmentation_T2 == 1, 1, 0) + PRO_T2 = np.uint8(PRO_T2) + + fprobs[:, 0, :, :] = fprobs[:, 0, :, :] == np.amax(fprobs, axis=1) + fprobs[:, 1, :, :] = fprobs[:, 1, :, :] == np.amax(fprobs, axis=1) + fprobs[:, 2, :, :] = fprobs[:, 2, :, :] == np.amax(fprobs, axis=1) + + ProstateOut = fprobs[:, 1, :, :] + TumorOut = fprobs[:, 2, :, :] + + probability_map_back = Mprobs[:, 0, :, :] + probability_map_pro = Mprobs[:, 1, :, :] + probability_map_tu = Mprobs[:, 2, :, :] + + if test == False: + try: + os.mkdir(folder_name + '/Val_Images') + except OSError: + pass + + try: + os.mkdir(folder_name + '/Val_Images/Epoch_{}'.format(epoch)) + except OSError: + pass + + save_images_to = folder_name + '/Val_Images/Epoch_{}/Patient_{}'.format(epoch, patient) + try: + os.mkdir(save_images_to) + except OSError: + pass + + else: + try: + os.mkdir(folder_name + '/Test_Images') + except OSError: + pass + + save_images_to = folder_name + '/Test_Images/Patient_{}'.format(patient) + try: + os.mkdir(save_images_to) + except OSError: + pass + + + ADCimage = image[:, 0, :, :] + BVALimage = image[:, 1, :, :] + T2image = image[:, 2, :, :] + + TumorOut = sitk.GetImageFromArray(np.uint8(TumorOut)) + ProstateOut = sitk.GetImageFromArray(np.uint8(ProstateOut)) + ADCimg = sitk.GetImageFromArray(ADCimage) + BVALimg = sitk.GetImageFromArray(BVALimage) + T2img = sitk.GetImageFromArray(T2image) + seg = sitk.GetImageFromArray(label) + seg_T2 = 
sitk.GetImageFromArray(label_T2) + pro = sitk.GetImageFromArray(PRO) + pro_T2 = sitk.GetImageFromArray(PRO_T2) + probsBack = sitk.GetImageFromArray(probability_map_back) + probsPRO = sitk.GetImageFromArray(probability_map_pro) + probsTU = sitk.GetImageFromArray(probability_map_tu) + + save(TumorOut, save_images_to + '/Tumor_Output.nrrd', Mask=True) + save(ProstateOut, save_images_to + '/Prostate_Output.nrrd', Mask=True) + save(ADCimg, save_images_to + '/ADCImage.nrrd') + save(BVALimg, save_images_to + '/BVALImage.nrrd') + save(T2img, save_images_to + '/T2Image.nrrd') + + save(seg, save_images_to + '/Label.nrrd', Mask=True) + save(seg_T2, save_images_to + '/Label_T2.nrrd', Mask=True) + save(pro, save_images_to + '/Pro_Label.nrrd', Mask=True) + save(pro_T2, save_images_to + '/Pro_Label_T2.nrrd', Mask=True) + save(probsBack, save_images_to + '/ProbabilityMapBack.nrrd') + save(probsTU, save_images_to + '/ProbabilityMapTU.nrrd') + save(probsPRO, save_images_to + '/ProbabilityMapPRO.nrrd') + torch.cuda.empty_cache() + + return val_losses.avg + + +def clear_image_data(folder_name, best_epoch, epoch): + for e in range(epoch+1): + if e == best_epoch: + print('best epoch') + else: + path_name = folder_name + '/Val_Images/Epoch_{}/'.format(e) + try: + shutil.rmtree(path_name) + except OSError: + pass + + +def save_checkpoint(state, filename): + torch.save(state, filename) + + +class AverageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def adjust_learning_rate(optimizer, arg_lr, epoch): + + lr = np.float32(arg_lr * 0.98 ** epoch) + + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return optimizer, lr + + +def resample(fixed, target_resample_resolution, Mask=False): + + if Mask == True: + Interpolator = sitk.sitkNearestNeighbor + else: + 
Interpolator = sitk.sitkBSpline + + fixed_spacing = fixed.GetSpacing() + fixed_spacing = np.array(fixed_spacing) + fixed_size = fixed.GetSize() + fixed_size = np.array(fixed_size) + + image_size = fixed_size * (fixed_spacing / target_resample_resolution) + image_size = np.around(image_size) + image_size = image_size.astype(np.uint32).tolist() + + resample = sitk.ResampleImageFilter() + resample.SetOutputSpacing(target_resample_resolution) + resample.SetInterpolator(Interpolator) + resample.SetOutputOrigin(fixed.GetOrigin()) + resample.SetOutputDirection(fixed.GetDirection()) + resample.SetSize(image_size) + + out = resample.Execute(fixed) + + return out + + +class BatchGenerator(DataLoaderBase): + + def __init__(self, data, BATCH_SIZE, split_idx, seed, ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std, + ProbabilityTumorSlices=None, epoch=None, test=False): + super(self.__class__, self).__init__(data=data, BATCH_SIZE=BATCH_SIZE, seed=False, num_batches=None) + self._split_idx = split_idx + self._ProbabilityTumorSlices = ProbabilityTumorSlices + self._epoch = epoch + self._s = 0 + self._count = 0 + self.test = test + self.seed = seed + self.ADC_mean = ADC_mean + self.BVAL_mean = BVAL_mean + self.T2_mean = T2_mean + self.ADC_std = ADC_std + self.BVAL_std = BVAL_std + self.T2_std = T2_std + + def generate_train_batch(self): + + channels_img = 3 + channels_label = 2 + img_size = 200 + + if self.test == True: + + img = np.empty((self.BATCH_SIZE, channels_img, img_size, img_size)) + label = np.empty((self.BATCH_SIZE, channels_label, img_size, img_size)) + + if self._count < len(self._split_idx): + idx = self._split_idx[self._count] + z_dim = self._data[idx]['image'].shape[3] + self.BATCH_SIZE = z_dim + + custom_batch = (z_dim / self.BATCH_SIZE)*self.BATCH_SIZE + + + if self._count < len(self._split_idx): + img = self._data[idx]['image'] + img = img.transpose((3,0,1,2)) + label = self._data[idx]['label'] + label = label.transpose((3,0,1,2)) + self._count += 1 + + + 
else: + for b in range(self.BATCH_SIZE): + if self._s < custom_batch: + img[b, :, :, :] = self._data[idx]['image'][:channels_img, :, :, self._s] + label[b, :, :, :] = self._data[idx]['label'][:channels_label, :, :, self._s] + self._s += 1 + else: + self._s = 0 + self._count += 1 + self._batches_generated = self._count + if self._count < len(self._split_idx): + idx = self._split_idx[self._batches_generated] + img[b, :, :, :] = self._data[idx]['image'][:channels_img, :, :, self._s] + label[b, :, :, :] = self._data[idx]['label'][:channels_label, :, :, self._s] + self._s += 1 + else: + pass + + else: + pass + + else: + + idx = np.random.choice(self._split_idx, self.BATCH_SIZE, False, None) + img = np.empty((self.BATCH_SIZE, channels_img, img_size, img_size)) + label = np.empty((self.BATCH_SIZE, channels_label, img_size, img_size)) + + for b in range(self.BATCH_SIZE): + + if self._ProbabilityTumorSlices is not None: + LabelData = self._data[idx[b]]['label'] + z_dim = self._data[idx[b]]['image'].shape[3] + CancerSlices = [] + for Slice in range(z_dim): + + bool = np.where(LabelData[:, :, :, Slice] == 2, True, False) + + if bool.any() == True: + CancerSlices.append(Slice) + + if sum(CancerSlices) is not 0: + CancerSlices = np.array(CancerSlices) + totalTumorSliceProb = float(self._ProbabilityTumorSlices) + totalOtherSliceProb = float(1) - float(totalTumorSliceProb) + + ProbabilityMap = np.array(np.zeros(z_dim)) + ProbabilityMap[CancerSlices] = self._ProbabilityTumorSlices / float(len(CancerSlices)) + + NoTumorSlices = float(z_dim - len(CancerSlices)) + ProbPerNoTumorSlice = totalOtherSliceProb / NoTumorSlices + + ProbabilityMap = [ProbPerNoTumorSlice if g == 0 else g for g in ProbabilityMap] + + else: + ProbabilityMap = np.zeros(z_dim) + ProbabilityMap = [float(1) / float(z_dim) if g == 0 else g for g in ProbabilityMap] + + else: + ProbabilityMap = np.zeros(z_dim) + ProbabilityMap = [float(1) / float(z_dim) if g == 0 else g for g in ProbabilityMap] + + randint = 
np.random.choice(z_dim, p=ProbabilityMap) + + img[b, :, :, :] = self._data[idx[b]]['image'][:, :, :, randint] + label[b, :, :, :] = self._data[idx[b]]['label'][:, :, :, randint] + + # cut off outliers before image normalization + + img = np.nan_to_num(img) + img = cut_off_outliers(img, percentile_lower=0.2, percentile_upper=99.8, per_channel=True) + + img[:, 0, :, :] = (img[:, 0, :, :] - np.float(self.ADC_mean)) / np.float(self.ADC_std) + img[:, 1, :, :] = (img[:, 1, :, :] - np.float(self.BVAL_mean)) / np.float(self.BVAL_std) + img[:, 2, :, :] = (img[:, 2, :, :] - np.float(self.T2_mean)) / np.float(self.T2_std) + + + + img = np.float32(img) + + data_dict = {"data": img, "seg": label} + + + return data_dict + + + + +def CreateTrainValTestSplit(HistoFile_path, num_splits, num_val_folds, num_test_folds, seed): + + np.random.seed(seed) + + HistoFile = pd.read_csv(HistoFile_path) + + # calculate random split assignments of the subjects + IDs = HistoFile.Master_ID.values # Patient no. + unique_IDs = np.unique(IDs) + + num_subjects = len(unique_IDs) + splits = -1 * np.ones(num_subjects) + s_per_split = num_subjects // num_splits + # assign an equivalent # subjects to the splits + assign = np.random.choice(range(num_subjects), size=(num_splits, s_per_split), replace=False) + for split in range(num_splits): + for subj in range(s_per_split): + splits[assign[split, subj]] = split + + # assign missing subjects + ixs = np.where(splits == -1) + splits[ixs] = np.random.randint(0, high=num_splits, size=len(ixs[0])) + + train_splits = [s for s in range(num_splits)] + subjects = [] + for s in train_splits: + split_ixs = np.where(splits == s) + + split_subjs = unique_IDs[split_ixs] + subjects.append(split_subjs) + + train_folds = [f for f in range(num_splits - num_val_folds - num_test_folds)] + train_ixs_lists = [subjects[fold] for fold in train_folds] + train_ixs = [ix for ixs in train_ixs_lists for ix in ixs] # flatten the list of lists + + val_folds = [f for f in range(num_splits 
- num_test_folds) if f not in train_folds] + val_ixs_lists = [subjects[fold] for fold in val_folds] + val_ixs = [ix for ixs in val_ixs_lists for ix in ixs] # flatten the list of lists + + test_folds = [f for f in range(num_splits) if f not in val_folds + train_folds] + test_ixs_lists = [subjects[fold] for fold in test_folds] + test_ixs = [ix for ixs in test_ixs_lists for ix in ixs] # flatten the list of lists + + + return train_ixs, val_ixs, test_ixs + + +def get_class_frequencies(Data, train_idx, patch_size): + + Tumor_frequencie_ADC = 0 + Prostate_frequencie_ADC = 0 + Background_frequencie_ADC = 0 + + Tumor_frequencie_T2 = 0 + Prostate_frequencie_T2 = 0 + Background_frequencie_T2 = 0 + + T2_mean = 0 + T2_std = 0 + ADC_mean = 0 + ADC_std = 0 + BVAL_mean = 0 + BVAL_std = 0 + + for i in range(len(train_idx)): + idx = train_idx[i] + Data_class = Data[idx]['label'] + Data_image = Data[idx]['image'] + + center_crop_dimensions = ((patch_size[0], patch_size[1])) + + Data_class_Label = center_crop(Data_class, center_crop_dimensions) + Data_class_Label = Data_class_Label[0] + + Data_image_cropped = center_crop(Data_image, center_crop_dimensions) + Data_image_cropped = Data_image_cropped[0] + + Data_image_cropped = np.nan_to_num(Data_image_cropped) + Data_image_cropped = cut_off_outliers(Data_image_cropped, percentile_lower=0.2, percentile_upper=99.8, per_channel=True) + + T2_mean += np.mean(Data_image_cropped[2,:,:,:]) + T2_std += np.std(Data_image_cropped[2,:,:,:]) + ADC_mean += np.mean(Data_image_cropped[0,:,:,:]) + ADC_std += np.std(Data_image_cropped[0,:,:,:]) + BVAL_mean += np.mean(Data_image_cropped[1,:,:,:]) + BVAL_std += np.std(Data_image_cropped[1,:,:,:]) + + + Tumor_frequencie_ADC += np.sum(Data_class_Label[0, :, :, :] == 2) + Prostate_frequencie_ADC += np.sum(Data_class_Label[0, :, :, :] == 1) + Background_frequencie_ADC += np.sum(Data_class_Label[0, :, :, :] == 0) + + Tumor_frequencie_T2 += np.sum(Data_class_Label[1, :, :, :] == 2) + Prostate_frequencie_T2 += 
np.sum(Data_class_Label[1, :, :, :] == 1) + Background_frequencie_T2 += np.sum(Data_class_Label[1, :, :, :] == 0) + + T2_mean = T2_mean / np.float(len(train_idx)) + T2_std = T2_std / np.float(len(train_idx)) + ADC_mean = ADC_mean / np.float(len(train_idx)) + ADC_std = ADC_std / np.float(len(train_idx)) + BVAL_mean = BVAL_mean / np.float(len(train_idx)) + BVAL_std = BVAL_std / np.float(len(train_idx)) + + return Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC,\ + Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2, ADC_mean, ADC_std, BVAL_mean, \ + BVAL_std, T2_mean, T2_std + + +def get_oversampling(Data, train_idx, Batch_Size, patch_size): + + IDs = sorted(train_idx) + + Total_Patients = len(IDs) + + print 'Total_Patients', Total_Patients + + Tumor_slices = 0 + Prostate_slices = 0 + Pos_Patient = 0 + Non_Tumor_Slices = 0 + Non_Prostate_Slices = 0 + + for i in range(len(IDs)): + idx = IDs[i] + print(idx) + z_Dim = (Data[idx]['label']).shape + Data_over = Data[idx]['label'] + + center_crop_dimensions = ((patch_size[0], patch_size[1])) + Data_over_cropped = center_crop(Data_over, center_crop_dimensions) + Data_over_cropped = Data_over_cropped[0] + + for f in range(z_Dim[3]): + if np.sum(Data_over_cropped[0,:,:,f] == 2) >= 1: + Tumor_slices += 1 + else: + Non_Tumor_Slices += 1 + + if np.sum(Data_over_cropped[0,:,:,f] == 1) >= 1: + Prostate_slices += 1 + else: + Non_Prostate_Slices += 1 + + if np.sum(Data_over_cropped == 2) >= 1: + Pos_Patient += 1 + + + print 'Tumor', Tumor_slices + print 'Non_Tumor', Non_Tumor_Slices + print 'Pos_Patients', Pos_Patient + + print 'Prostate', Prostate_slices + print 'Non_Prostate_Slices', Non_Prostate_Slices + + + + print(Tumor_slices, Non_Tumor_Slices) + + + Probability_Pos_Patient = Pos_Patient / np.float(Total_Patients) + + print 'Probability_Pos_Patient', Probability_Pos_Patient + + Slices_total = Tumor_slices + Non_Tumor_Slices + + Natural_probability_tu_slice = Tumor_slices/ 
np.float(Slices_total) + + Natural_probability_Prostate_slice = Prostate_slices / np.float(Slices_total) + + print 'Natural_probability_tu_slice', Natural_probability_tu_slice + print 'Natural_probability_PRO_slice', Natural_probability_Prostate_slice + + Oversampling_Factor = (Natural_probability_tu_slice ** (1/np.float(Batch_Size * Probability_Pos_Patient))) + + print 'oversampling_factor', Oversampling_Factor + + return (Oversampling_Factor), Slices_total, Natural_probability_tu_slice, Natural_probability_Prostate_slice + + + + +def split_training(IDs, len_val, cv, cv_runs): + + runs = cv_runs / 4 + for s in range(runs): + if cv in range(s * 4, s * 4 + 4): + count = s + print count + + np.random.seed(42 + count) + + IDx = np.array(IDs) + np.random.shuffle(IDx) + + factor = cv - count * 4 + + print factor + lower_limit = len_val * factor + print(lower_limit) + + upper_limit = len_val * (factor + 1) + + if factor == 3: + upper_limit = len(IDx) + print(upper_limit) + + val_idx = IDx[lower_limit:upper_limit] + train_idx = np.setdiff1d(IDs, val_idx) + + return train_idx, val_idx + + +def save(Image, OutputFilePath, Mask = False): + + if Mask == True: + sitk.WriteImage(sitk.Cast(Image, sitk.sitkUInt8), OutputFilePath) + else: + sitk.WriteImage(sitk.Cast(Image, sitk.sitkFloat32), OutputFilePath) + + finalImg = sitk.ReadImage(OutputFilePath) + finalImg = sitk.GetArrayFromImage(finalImg) + finalImg = finalImg.swapaxes(0, 2) + finalImg_data, finalImg_options = nrrd.read(OutputFilePath) + finalImg_options['encoding'] = 'gzip' + nrrd.write(OutputFilePath, finalImg, header=finalImg_options) \ No newline at end of file