diff --git a/UNet_main_test.py b/UNet_main_test.py
new file mode 100755
index 0000000..befa7a4
--- /dev/null
+++ b/UNet_main_test.py
@@ -0,0 +1,172 @@
+# UNet_main_test.py : Main Code for testing the UNet accompanying publication "Classification of prostate cancer on MRI: Deep learning vs. clinical PI-RADS assessment", Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel Wiesenfarth PhD, Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan Anselm Kuder PhD, Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter Schlemmer MD, PhD, Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology, [manuscript accepted for publication]
+# Copyright (C) 2019 German Cancer Research Center (DKFZ)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de
+
+__author__ = "German Cancer Research Center (DKFZ)"
+
+
+import os
+import numpy as np
+import pandas as pd
+import torch
+import argparse
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, Mirror
+from batchgenerators.transforms.noise_transforms import RicianNoiseTransform
+from batchgenerators.transforms.resample_transforms import ResampleTransform
+from batchgenerators.transforms.abstract_transforms import RndTransform
+from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
+from batchgenerators.transforms.sample_normalization_transforms import CutOffOutliersTransform
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, BrightnessTransform, ContrastAugmentationTransform, GammaTransform
+from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform
+from UNet_net import UNetPytorch
+from UNet_utils import BatchGenerator, CreateTrainValTestSplit, validate, get_class_frequencies, CrossEntropyLoss2d
+
+
+####################################################################################################################
+### Settings
+####################################################################################################################
+
+parser = argparse.ArgumentParser(description='RADIOLOGY-2019 Prostate Lesion Segmentation ensemble testing script')
+# Command-line options. --patch_size is parsed from "H,W": argparse's type=tuple would split the raw string into single characters.
+parser.add_argument('--input_file', default='', type=str, help='name of input file')
+parser.add_argument('--ensemble_path', default='', type=str, help='path to pretrained CNN ensemble')
+parser.add_argument('--checkpoint_name', default='checkpoint_UNetPytorch.pth.tar', type=str, help='name of ensemble checkpoint files')
+parser.add_argument('--output_path', default='', type=str, help='name of output folder')
+parser.add_argument('--num_splits', default=10, type=int, help='split dataset in 10 folds')
+parser.add_argument('--num_val_folds', default=2, type=int, help='leave x fold out for validation')
+parser.add_argument('--num_test_folds', default=2, type=int, help='leave x fold out for testing')
+parser.add_argument('--seed', default=42, type=int, help='seed')
+parser.add_argument('--patch_size', default=(160, 160), type=lambda s: tuple(int(v) for v in s.split(',')), help='define patch size as H,W (e.g. 160,160)')
+
+
+def main():
+    """Evaluate every checkpoint found below --ensemble_path on the held-out test split and write one TestCSV.csv per network."""
+    # assign global args
+    global args
+    args = parser.parse_args()
+
+
+    # make a folder for the experiment
+    general_folder_name = args.output_path
+    try:
+        os.mkdir(general_folder_name)
+    except OSError:
+        pass  # best effort: any OSError (e.g. the folder already exists) is ignored
+
+    # create train, test split, return the indices, patients in test_split won't be seen during whole training
+    train_idx, val_idx, test_idx = CreateTrainValTestSplit(HistoFile_path=args.input_file, num_splits=args.num_splits, num_test_folds=args.num_test_folds,
+                                                           num_val_folds=args.num_val_folds, seed=args.seed)
+
+    print('size of training set {}'.format(len(train_idx)))
+    print('size of validation set {}'.format(len(val_idx)))
+    print('size of test set {}'.format(len(test_idx)))
+    # the class statistics below are computed on training and validation patients together
+    train_idx = train_idx + val_idx
+    # NOTE(review): ProstataData is neither defined nor imported in this script; see the README for how to provide it
+    # data loading
+    Data = ProstataData(args.input_file) #For details on this class see README
+
+
+    # get class frequencies
+    print('calculating class frequencie')
+
+    Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC, \
+    Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2 \
+    , ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std \
+    = get_class_frequencies(Data, train_idx, patch_size=args.patch_size)
+    # single-string print() calls behave the same under Python 2 and Python 3
+    print('{} {} {} {} {} {}'.format(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std))
+
+    print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC)
+    print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2)
+    # np.float was only an alias of the builtin float and was removed in NumPy 1.24
+    all_ADC = float(Background_frequencie_ADC + Prostate_frequencie_ADC + Tumor_frequencie_ADC)
+    all_T2 = float(Background_frequencie_T2 + Prostate_frequencie_T2 + Tumor_frequencie_T2)
+
+    print(all_ADC)
+    print(all_T2)
+    # loss weights: inverse class fraction to the power 0.25, normalised below so the three weights sum to one
+    W1_ADC = 1 / (Background_frequencie_ADC / all_ADC) ** 0.25
+    W2_ADC = 1 / (Prostate_frequencie_ADC / all_ADC) ** 0.25
+    W3_ADC = 1 / (Tumor_frequencie_ADC / all_ADC) ** 0.25
+
+    Wa_ADC = W1_ADC / (W1_ADC + W2_ADC + W3_ADC)
+    Wb_ADC = W2_ADC / (W1_ADC + W2_ADC + W3_ADC)
+    Wc_ADC = W3_ADC / (W1_ADC + W2_ADC + W3_ADC)
+
+    print('Weights ADC {} {} {}'.format(Wa_ADC, Wb_ADC, Wc_ADC))
+
+    weight_ADC = (Wa_ADC, Wb_ADC, Wc_ADC)
+
+    W1_T2 = 1 / (Background_frequencie_T2 / all_T2) ** 0.25
+    W2_T2 = 1 / (Prostate_frequencie_T2 / all_T2) ** 0.25
+    W3_T2 = 1 / (Tumor_frequencie_T2 / all_T2) ** 0.25
+
+    Wa_T2 = W1_T2 / (W1_T2 + W2_T2 + W3_T2)
+    Wb_T2 = W2_T2 / (W1_T2 + W2_T2 + W3_T2)
+    Wc_T2 = W3_T2 / (W1_T2 + W2_T2 + W3_T2)
+
+    print('Weights T2 {} {} {}'.format(Wa_T2, Wb_T2, Wc_T2))
+
+    weight_T2 = (Wa_T2, Wb_T2, Wc_T2)
+
+    criterion_ADC = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_ADC)).cuda()
+    criterion_T2 = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_T2)).cuda()
+
+    Center_Crop = CenterCropTransform(args.patch_size)
+    count = 0
+    # every file whose name starts with --checkpoint_name is evaluated; its results go to <output_path>/Net_<count>
+    for folder, _subfolders, files in sorted(os.walk(args.ensemble_path)):
+        for fname in files:
+            if fname.startswith(args.checkpoint_name):
+                epoch = 0
+
+                # define model
+                Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
+                model = Net.cuda()
+
+                folder_name_for_single_nets = general_folder_name + '/Net_{}'.format(count)
+
+                try:
+                    os.mkdir(folder_name_for_single_nets)
+                except OSError:
+                    pass
+
+                model_path = os.path.join(os.path.abspath(folder), fname)
+                checkpoint = torch.load(model_path)
+                model.load_state_dict(checkpoint['state_dict'])
+                test_loader = BatchGenerator(Data, BATCH_SIZE=0, split_idx=test_idx, seed=args.seed,
+                                             ProbabilityTumorSlices=None, epoch=epoch, test=True,
+                                             ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean, BVAL_std=
+                                             BVAL_std, T2_mean=T2_mean, T2_std=T2_std
+                                             )
+
+                test_loss = validate(test_loader, model,
+                                     folder_name=folder_name_for_single_nets, criterion_ADC=criterion_ADC, criterion_T2=criterion_T2,
+                                     split_ixs=test_idx, epoch=epoch, workers=1, Center_Crop=Center_Crop, seed=args.seed, test=True)
+
+                TestCSV = pd.DataFrame({'test_loss': [test_loss]})
+                TestCSV.to_csv(folder_name_for_single_nets + '/TestCSV.csv')
+                count += 1
+
+if __name__ == '__main__':
+    main()
+
+
+
+
diff --git a/UNet_main_train.py b/UNet_main_train.py
new file mode 100755
index 0000000..7682b3e
--- /dev/null
+++ b/UNet_main_train.py
@@ -0,0 +1,304 @@
+# UNet_main_train.py : Main Code for training the UNet accompanying publication "Classification of prostate cancer on MRI: Deep learning vs. clinical PI-RADS assessment", Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel Wiesenfarth PhD, Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan Anselm Kuder PhD, Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter Schlemmer MD, PhD, Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology, [manuscript accepted for publication]
+# Copyright (C) 2019 German Cancer Research Center (DKFZ)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de
+
+__author__ = "German Cancer Research Center (DKFZ)"
+
+
+import numpy as np
+import pandas as pd
+import os
+import torch
+import argparse
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, Mirror
+from batchgenerators.transforms.noise_transforms import RicianNoiseTransform
+from batchgenerators.transforms.resample_transforms import ResampleTransform
+from batchgenerators.transforms.abstract_transforms import RndTransform
+from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
+from batchgenerators.transforms.sample_normalization_transforms import CutOffOutliersTransform
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, BrightnessTransform, ContrastAugmentationTransform, GammaTransform
+from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform
+from UNet_net import UNetPytorch
+from UNet_utils import train, save_checkpoint, adjust_learning_rate\
+ ,BatchGenerator, CreateTrainValTestSplit, CrossEntropyLoss2d, get_class_frequencies, clear_image_data, \
+ split_training, validate, \
+ get_oversampling
+
+
+####################################################################################################################
+### Settings
+####################################################################################################################
+
+parser = argparse.ArgumentParser(description='RADIOLOGY-2019 Prostate Lesion Segmentation training script')
+# Command-line options. --patch_size is parsed from "H,W": argparse's type=tuple would split the raw string into single characters.
+parser.add_argument('--input_file', default='', type=str, help='name of input file')
+parser.add_argument('--output_path', default='', type=str, help='name of output folder')
+parser.add_argument('--arch',default='UNet')
+parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loading workers')
+parser.add_argument('--epochs', default=80, type=int, metavar='N', help='number of total epochs to run')
+parser.add_argument('--b', default=32, type=int, help='mini-batch size')
+parser.add_argument('--lr', default=0.00015, type=float, help='initial learning rate')
+parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
+parser.add_argument('--num_splits', default=10, type=int, help='split dataset in 10 folds')
+parser.add_argument('--num_val_folds', default=2, type=int, help='leave x fold out for validation')
+parser.add_argument('--num_test_folds', default=2, type=int, help='leave x fold out for testing')
+parser.add_argument('--seed', default=42, type=int, help='seed')
+parser.add_argument('--patch_size', default=(160, 160), type=lambda s: tuple(int(v) for v in s.split(',')), help='define patch size as H,W (e.g. 160,160)')
+parser.add_argument('--p', default=1, type=float, help='probability for spatial transform')
+parser.add_argument('--cv_number', default=16, type=int, help='number of cross validation loops')
+parser.add_argument('--cv_start', default=0, type=int, help='')
+parser.add_argument('--cv_end', default=16, type=int, help='')
+
+
+def main():
+    """Train one U-Net per cross-validation run and keep, per run, the checkpoint with the lowest validation loss."""
+    # assign global args
+    global args
+    args = parser.parse_args()
+
+
+    # make a folder for the experiment
+    general_folder_name = args.output_path
+    try:
+        os.mkdir(general_folder_name)
+    except OSError:
+        pass  # best effort: any OSError (e.g. the folder already exists) is ignored
+
+
+    # create train, test split, return the indices, patients in test_split won't be seen during whole training
+    train_idx, val_idx, test_idx = CreateTrainValTestSplit(HistoFile_path=args.input_file, num_splits=args.num_splits, num_test_folds=args.num_test_folds,
+                                                           num_val_folds=args.num_val_folds, seed=args.seed)
+    # training and validation patients are pooled; each CV run re-splits this pool below
+    IDs = train_idx + val_idx
+
+    print('size of training set {}'.format(len(train_idx)))
+    print('size of validation set {}'.format(len(val_idx)))
+    print('size of test set {}'.format(len(test_idx)))
+
+    # NOTE(review): ProstataData is neither defined nor imported in this script; see the README for how to provide it
+    # data loading
+    Data = ProstataData(args.input_file) #For details on this class see README
+
+    # train and validate
+    for cv in range(args.cv_start, args.cv_end):
+        best_epoch = 0
+        train_loss = []
+        val_loss = []
+
+        # define patients for training and validation
+        train_idx, val_idx = split_training(IDs, len_val=62, cv=cv, cv_runs=args.cv_number)
+
+        oversampling_factor, Slices_total, Natural_probability_tu_slice, Natural_probability_PRO_slice = get_oversampling(Data, train_idx=sorted(train_idx), Batch_Size=args.b, patch_size=args.patch_size)
+        # NOTE(review): '/' floors under Python 2 when both operands are ints but not under Python 3 - confirm which is intended
+        training_batches = Slices_total / args.b
+
+        lr = args.lr
+        base_lr = args.lr
+        args.seed += 1  # every CV run uses a different seed
+
+        print('train_idx', train_idx, len(train_idx))
+        print('val_idx', val_idx, len(val_idx))
+
+        # get class frequencies
+        print('calculating class frequencie')
+
+        Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC,\
+        Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2\
+        , ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std \
+        = get_class_frequencies(Data, train_idx, patch_size=args.patch_size)
+        # single-string print() calls behave the same under Python 2 and Python 3
+        print('{} {} {} {} {} {}'.format(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std))
+
+        print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC)
+        print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2)
+        # np.float was only an alias of the builtin float and was removed in NumPy 1.24
+        all_ADC = float(Background_frequencie_ADC + Prostate_frequencie_ADC + Tumor_frequencie_ADC)
+        all_T2 = float(Background_frequencie_T2 + Prostate_frequencie_T2 + Tumor_frequencie_T2)
+
+        print(all_ADC)
+        print(all_T2)
+        # loss weights: inverse class fraction to the power 0.25, normalised below so the three weights sum to one
+        W1_ADC = 1 / (Background_frequencie_ADC / all_ADC) ** 0.25
+        W2_ADC = 1 / (Prostate_frequencie_ADC / all_ADC) ** 0.25
+        W3_ADC = 1 / (Tumor_frequencie_ADC / all_ADC) ** 0.25
+
+        Wa_ADC = W1_ADC / (W1_ADC + W2_ADC + W3_ADC)
+        Wb_ADC = W2_ADC / (W1_ADC + W2_ADC + W3_ADC)
+        Wc_ADC = W3_ADC / (W1_ADC + W2_ADC + W3_ADC)
+
+        print('Weights ADC {} {} {}'.format(Wa_ADC, Wb_ADC, Wc_ADC))
+
+        weight_ADC = (Wa_ADC, Wb_ADC, Wc_ADC)
+
+        W1_T2 = 1 / (Background_frequencie_T2 / all_T2) ** 0.25
+        W2_T2 = 1 / (Prostate_frequencie_T2 / all_T2) ** 0.25
+        W3_T2 = 1 / (Tumor_frequencie_T2 / all_T2) ** 0.25
+
+        Wa_T2 = W1_T2 / (W1_T2 + W2_T2 + W3_T2)
+        Wb_T2 = W2_T2 / (W1_T2 + W2_T2 + W3_T2)
+        Wc_T2 = W3_T2 / (W1_T2 + W2_T2 + W3_T2)
+
+        print('Weights T2 {} {} {}'.format(Wa_T2, Wb_T2, Wc_T2))
+
+        weight_T2 = (Wa_T2, Wb_T2, Wc_T2)
+
+        # define model
+        Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
+        Net_Name = 'UNetPytorch'
+        model = Net.cuda()
+
+        # model parameter
+        optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=args.weight_decay)
+        criterion_ADC = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_ADC)).cuda()
+        criterion_T2 = CrossEntropyLoss2d(weight=torch.FloatTensor(weight_T2)).cuda()
+
+
+        # new folder name for cv
+        folder_name = general_folder_name + '/CV_{}'.format(cv)
+        try:
+            os.mkdir(folder_name)
+        except OSError:
+            pass
+
+        checkpoint_file = folder_name + '/checkpoint_' + '{}.pth.tar'.format(Net_Name)
+
+        # augmentation
+        for epoch in range(args.epochs):
+            torch.manual_seed(args.seed + epoch + cv)
+            np.random.seed(epoch + cv)
+            np.random.shuffle(train_idx)
+            # the augmentation pipeline is rebuilt at epochs 0, 30 and 50 and reused for the epochs in between
+            if epoch == 0:
+                my_transforms = []
+                spatial_transform = SpatialTransform(args.patch_size, np.array(args.patch_size) // 2,
+                                                     do_elastic_deform=True, alpha=(100., 450.),
+                                                     sigma=(13., 17.),
+                                                     do_rotation=True, angle_z=(-np.pi / 2., np.pi / 2.),
+                                                     do_scale=True, scale=(0.75, 1.25),
+                                                     border_mode_data='constant', border_cval_data=0,
+                                                     order_data=3,
+                                                     random_crop=True)
+                resample_transform = ResampleTransform(zoom_range=(0.7, 1.3))
+                brightness_transform = BrightnessTransform(0.0, 0.1, True)
+                my_transforms.append(resample_transform)
+                my_transforms.append(ContrastAugmentationTransform((0.75, 1.25), True))
+                my_transforms.append(brightness_transform)
+                my_transforms.append(Mirror((2, 3)))
+                all_transforms = Compose(my_transforms)
+                sometimes_spatial_transforms = RndTransform(spatial_transform, prob=args.p,
+                                                            alternative_transform=CenterCropTransform(
+                                                                args.patch_size))
+                sometimes_other_transforms = RndTransform(all_transforms, prob=1.0)
+                final_transform = Compose([sometimes_spatial_transforms, sometimes_other_transforms])
+                Center_Crop = CenterCropTransform(args.patch_size)
+            # from epoch 30 on: smaller deformation, scale, zoom and contrast ranges, and no Mirror transform
+            if epoch == 30:
+                my_transforms = []
+                spatial_transform = SpatialTransform(args.patch_size, np.array(args.patch_size) // 2,
+                                                     do_elastic_deform=True, alpha=(0., 250.),
+                                                     sigma=(11., 14.),
+                                                     do_rotation=True, angle_z=(-np.pi / 2., np.pi / 2.),
+                                                     do_scale=True, scale=(0.85, 1.15),
+                                                     border_mode_data='constant', border_cval_data=0,
+                                                     order_data=3,
+                                                     random_crop=True)
+                resample_transform = ResampleTransform(zoom_range=(0.8, 1.2))
+                brightness_transform = BrightnessTransform(0.0, 0.1, True)
+                my_transforms.append(resample_transform)
+                my_transforms.append(ContrastAugmentationTransform((0.85, 1.15), True))
+                my_transforms.append(brightness_transform)
+                all_transforms = Compose(my_transforms)
+                sometimes_spatial_transforms = RndTransform(spatial_transform, prob=args.p,
+                                                            alternative_transform=CenterCropTransform(
+                                                                args.patch_size))
+                sometimes_other_transforms = RndTransform(all_transforms, prob=1.0)
+                final_transform = Compose([sometimes_spatial_transforms, sometimes_other_transforms])
+                Center_Crop = CenterCropTransform(args.patch_size)
+            # from epoch 50 on: ranges are reduced again and random_crop is switched off
+            if epoch == 50:
+                my_transforms = []
+                spatial_transform = SpatialTransform(args.patch_size, np.array(args.patch_size) // 2,
+                                                     do_elastic_deform=True, alpha=(0., 150.),
+                                                     sigma=(10., 12.),
+                                                     do_rotation=True, angle_z=(-np.pi / 2., np.pi / 2.),
+                                                     do_scale=True, scale=(0.85, 1.15),
+                                                     border_mode_data='constant', border_cval_data=0,
+                                                     order_data=3,
+                                                     random_crop=False)
+                resample_transform = ResampleTransform(zoom_range=(0.9, 1.1))
+                brightness_transform = BrightnessTransform(0.0, 0.1, True)
+                my_transforms.append(resample_transform)
+                my_transforms.append(ContrastAugmentationTransform((0.95, 1.05), True))
+                my_transforms.append(brightness_transform)
+                all_transforms = Compose(my_transforms)
+                sometimes_spatial_transforms = RndTransform(spatial_transform, prob=args.p,
+                                                            alternative_transform=CenterCropTransform(
+                                                                args.patch_size))
+                sometimes_other_transforms = RndTransform(all_transforms, prob=1.0)
+                final_transform = Compose([sometimes_spatial_transforms, sometimes_other_transforms])
+                Center_Crop = CenterCropTransform(args.patch_size)
+
+
+            train_loader = BatchGenerator(Data, BATCH_SIZE=args.b, split_idx=train_idx, seed=args.seed,
+                                          ProbabilityTumorSlices=oversampling_factor, epoch=epoch,
+                                          ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean, BVAL_std=
+                                          BVAL_std, T2_mean=T2_mean, T2_std=T2_std)
+
+            val_loader = BatchGenerator(Data, BATCH_SIZE=0, split_idx=val_idx, seed=args.seed,
+                                        ProbabilityTumorSlices=None, epoch=epoch, test=True,
+                                        ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean, BVAL_std=
+                                        BVAL_std, T2_mean=T2_mean, T2_std=T2_std
+                                        )
+
+
+            #train on training set
+            train_losses = train(train_loader=train_loader, model=model, optimizer=optimizer,
+                                 criterion_ADC=criterion_ADC, criterion_T2=criterion_T2,
+                                 final_transform=final_transform, workers=args.workers, seed=args.seed,
+                                 training_batches=training_batches)
+            train_loss.append(train_losses)
+
+
+            # evaluate on validation set
+            val_losses = validate(val_loader=val_loader, model=model, folder_name=folder_name,
+                                  criterion_ADC =criterion_ADC, criterion_T2=criterion_T2, split_ixs=val_idx,
+                                  epoch=epoch, workers=1, Center_Crop=Center_Crop, seed=args.seed)
+            val_loss.append(val_losses)
+
+
+            # write TrainingsCSV to folder name
+            TrainingsCSV = pd.DataFrame({'train_loss': train_loss, 'val_loss': val_loss})
+            TrainingsCSV.to_csv(folder_name + '/TrainingsCSV.csv')
+            # val_loss already contains the current epoch, so this is true whenever the epoch ties or beats the best so far
+            if val_losses <= min(val_loss):
+                best_epoch = epoch
+                print('best epoch {}'.format(epoch))
+                save_checkpoint({'epoch': epoch, 'arch': args.arch, 'state_dict': model.state_dict(),
+                                 'optimizer': optimizer.state_dict()
+                                 }, filename=checkpoint_file)
+
+            optimizer, lr = adjust_learning_rate(optimizer, base_lr, epoch)
+
+            # delete all output except best epoch
+            clear_image_data(folder_name, best_epoch, epoch)
+
+
+if __name__ == '__main__':
+    main()
+
+
diff --git a/UNet_net.py b/UNet_net.py
new file mode 100755
index 0000000..beb4379
--- /dev/null
+++ b/UNet_net.py
@@ -0,0 +1,126 @@
+# UNet_net.py : CNN architecture accompanying publication "Classification of prostate cancer on MRI: Deep learning vs. clinical PI-RADS assessment", Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel Wiesenfarth PhD, Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan Anselm Kuder PhD, Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter Schlemmer MD, PhD, Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology, [manuscript accepted for publication]
+# Copyright (C) 2019 German Cancer Research Center (DKFZ)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de
+
+__author__ = "German Cancer Research Center (DKFZ)"
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# modified Unet implementation from https://www.kaggle.com/c/carvana-image-masking-challenge/discussion/37208
+
+BN_EPS = 1e-5
+
+class UNetPytorch(nn.Module):
+    """U-Net with five encoder stages, a bottleneck, five decoder stages and a 1x1 convolution with 3 output channels."""
+    def __init__(self, in_shape):
+        super(UNetPytorch, self).__init__()
+        # in_shape is unpacked as (C, H, W); only the channel count is used below
+        num_channels, _, _ = in_shape
+
+        # encoder path
+        self.down2 = StackEncoder(num_channels, 64, kernel_size=3)
+        self.down3 = StackEncoder(64, 128, kernel_size=3)
+        self.down4 = StackEncoder(128, 256, kernel_size=3)
+        self.down5 = StackEncoder(256, 512, kernel_size=3)
+        self.down6 = StackEncoder(512, 1024, kernel_size=3)
+
+        # bottleneck: three identical 3x3 blocks at 1024 channels
+        bottleneck = [ConvRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1) for _ in range(3)]
+        self.center = nn.Sequential(*bottleneck)
+
+        # decoder path: arguments are (skip channels, incoming channels, output channels)
+        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)
+        self.up5 = StackDecoder(512, 512, 256, kernel_size=3)
+        self.up4 = StackDecoder(256, 256, 128, kernel_size=3)
+        self.up3 = StackDecoder(128, 128, 64, kernel_size=3)
+        self.up2 = StackDecoder(64, 64, 32, kernel_size=3)
+
+        # final 1x1 convolution mapping 32 feature channels to 3 output channels
+        self.classify = nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=True)
+
+
+    def forward(self, input):
+        """Encode, pass through the bottleneck, decode with the stored skip tensors and classify."""
+        skips = []
+        features = input
+        for encoder in (self.down2, self.down3, self.down4, self.down5, self.down6):
+            # every encoder returns (features before pooling, pooled features)
+            skip, features = encoder(features)
+            skips.append(skip)
+
+        features = self.center(features)
+
+        # decoders consume the stored skip tensors deepest-first
+        for decoder in (self.up6, self.up5, self.up4, self.up3, self.up2):
+            features = decoder(skips.pop(), features)
+
+        # the raw output of self.classify is returned; no activation is applied here
+        return self.classify(features)
+
+
+class ConvRelu2d(nn.Module):  # bias-free Conv2d -> optional in-place ReLU -> optional BatchNorm2d, in that order
+    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, stride=1, groups=1, is_relu=True, is_bn=True):
+        super(ConvRelu2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
+        self.relu = nn.ReLU(inplace=True)
+        self.bn = nn.BatchNorm2d(out_channels, eps=BN_EPS)
+        if is_relu is False: self.relu = None  # layers are built first and dropped only for an explicit False
+        if is_bn is False: self.bn = None
+
+    def forward(self, input):
+        result = self.conv(input)
+        for layer in (m for m in (self.relu, self.bn) if m is not None):
+            result = layer(result)
+        return result
+
+class StackEncoder(nn.Module):  # three ConvRelu2d blocks followed by 2x2 max pooling with stride 2
+    def __init__(self, in_channels, out_channels, kernel_size=3):
+        super(StackEncoder, self).__init__()
+        pad = (kernel_size - 1) // 2
+        shared = dict(kernel_size=kernel_size, padding=pad, dilation=1, stride=1, groups=1)
+
+        # only the first block changes the channel count
+        self.encode = nn.Sequential(
+            ConvRelu2d(in_channels, out_channels, **shared),
+            ConvRelu2d(out_channels, out_channels, **shared),
+            ConvRelu2d(out_channels, out_channels, **shared))
+
+    def forward(self, input):
+        features = self.encode(input)
+        return features, F.max_pool2d(features, kernel_size=2, stride=2)  # (before pooling, after pooling)
+
+
+class StackDecoder(nn.Module):  # bilinear resize to the skip tensor's size, channel concatenation, three ConvRelu2d blocks
+    def __init__(self, in_channels_down, in_channels, out_channels, kernel_size=3):
+        super(StackDecoder, self).__init__()
+        pad = (kernel_size - 1) // 2
+        shared = dict(kernel_size=kernel_size, padding=pad, dilation=1, stride=1, groups=1)
+
+        # the first block receives the resized input concatenated with the skip tensor
+        self.decode = nn.Sequential(
+            ConvRelu2d(in_channels_down + in_channels, out_channels, **shared),
+            ConvRelu2d(out_channels, out_channels, **shared),
+            ConvRelu2d(out_channels, out_channels, **shared))
+
+    def forward(self, down_input, input):
+        _, _, height, width = down_input.size()  # the unpacking requires a 4-dimensional skip tensor
+        resized = F.interpolate(input, size=(height, width), mode='bilinear', align_corners=True)
+        return self.decode(torch.cat([resized, down_input], 1))
+
diff --git a/UNet_utils.py b/UNet_utils.py
new file mode 100755
index 0000000..12ace95
--- /dev/null
+++ b/UNet_utils.py
@@ -0,0 +1,650 @@
+# UNet_utils.py : Utilities Code (function definitions) accompanying publication "Classification of prostate cancer on MRI: Deep learning vs. clinical PI-RADS assessment", Patrick Schelb, Simon Kohl, Jan Philipp Radtke MD, Manuel Wiesenfarth PhD, Philipp Kickingereder MD, Sebastian Bickelhaupt, Tristan Anselm Kuder PhD, Albrecht Stenzinger, Markus Hohenfellner MD, Heinz-Peter Schlemmer MD, PhD, Klaus H. Maier-Hein PhD, David Bonekamp MD, Radiology, [manuscript accepted for publication]
+# Copyright (C) 2019 German Cancer Research Center (DKFZ)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+# contact: David Bonekamp, MD, d.bonekamp@dkfz-heidelberg.de
+
+__author__ = "German Cancer Research Center (DKFZ)"
+
+import pandas as pd
+import os
+import SimpleITK as sitk
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.dataloading.data_loader import DataLoaderBase
+from builtins import object
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import torch
+import shutil
+import nrrd
+from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop
+from batchgenerators.transforms.sample_normalization_transforms import CutOffOutliersTransform
+from batchgenerators.augmentations.normalizations import cut_off_outliers
+
+
class CrossEntropyLoss2d(nn.Module):
    """Cross-entropy loss for per-pixel class scores.

    Applies ``log_softmax`` over the class dimension (dimension 1) of
    ``inputs`` and scores the result with a negative log-likelihood loss.

    Args:
        weight: optional per-class weights, forwarded to ``nn.NLLLoss``.
        size_average: forwarded unchanged to ``nn.NLLLoss``.
        reduce: forwarded unchanged to ``nn.NLLLoss``.
    """

    def __init__(self, weight=None, size_average=True, reduce=True):
        super(CrossEntropyLoss2d, self).__init__()
        # nn.NLLLoss accepts (N, C, d1, ...) inputs itself; the separate
        # NLLLoss2d class is deprecated.
        self.nll_loss = nn.NLLLoss(weight=weight, size_average=size_average, reduce=reduce)

    def forward(self, inputs, targets):
        """Return the loss of raw scores ``inputs`` against class indices ``targets``."""
        # Name the class axis explicitly: the implicit choice is deprecated and
        # would pick dimension 0 for 3-D inputs.
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
+
+
def ToTensor(batch):
    """Turn the NumPy arrays of a batch dict into torch tensors.

    Args:
        batch: dict with 4-D arrays under ``'data'`` and ``'seg'``.

    Returns:
        dict with ``'data'`` (the full image array), ``'seg'`` (channel 0 of
        the ``'seg'`` array) and ``'seg_T2'`` (channel 1 of the ``'seg'``
        array), each wrapped with ``torch.from_numpy``.
    """
    images = batch['data']
    labels = batch['seg']

    first_label_channel = labels[:, 0, :, :]
    second_label_channel = labels[:, 1, :, :]

    return {
        'data': torch.from_numpy(images[:, :, :, :]),
        'seg': torch.from_numpy(first_label_channel),
        'seg_T2': torch.from_numpy(second_label_channel),
    }
+
+
+
def train(train_loader, model, optimizer, criterion_ADC, criterion_T2, final_transform, workers, seed,
          training_batches):
    """Run ``training_batches`` optimisation steps and return the mean loss.

    Batches come from a ``MultiThreadedAugmenter`` built around
    ``train_loader`` and ``final_transform``.  The model output is scored
    against both label channels (``criterion_ADC`` on ``'seg'``,
    ``criterion_T2`` on ``'seg_T2'``) and the two losses are averaged before
    back-propagation.  Inputs and targets are moved with ``.cuda()``, so a
    CUDA device is required.

    Side effects: reseeds NumPy's global RNG with ``seed``, switches ``model``
    to training mode, updates ``optimizer`` and prints progress plus the loss
    of every batch.

    Returns:
        The running average (``AverageMeter.avg``) of the per-batch losses.
    """
    train_losses = AverageMeter()
    np.random.seed(seed)
    # One distinct seed per worker, drawn without replacement from range(seed).
    seeds = np.random.choice(seed, workers, False, None)
    model.train()
    multithreaded_generator = MultiThreadedAugmenter(train_loader, final_transform, workers, 2, seeds=seeds)
    torch.cuda.empty_cache()

    for i in range(training_batches):
        print('Batch: [{0}/{1}]'.format(i + 1, training_batches))
        tensors = ToTensor(multithreaded_generator.next())

        # `async` became a reserved word in Python 3.7; PyTorch renamed the
        # keyword to `non_blocking` in 0.4.
        input_var = torch.autograd.Variable(tensors['data'], requires_grad=True).cuda(non_blocking=True)
        input_var = input_var.float()
        target_var = tensors['seg'].cuda().long()
        target_var_T2 = tensors['seg_T2'].cuda().long()

        optimizer.zero_grad()
        output = model(input_var)
        loss_ADC = criterion_ADC(output, target_var)
        loss_T2 = criterion_T2(output, target_var_T2)
        loss = (loss_ADC + loss_T2) / 2.
        loss.backward()
        optimizer.step()

        train_losses.update(loss.item())
        # Single-argument print() gives the same text on Python 2 and 3.
        print('train_loss {}'.format(loss.item()))

    torch.cuda.empty_cache()

    return train_losses.avg
+
+
+
def _ensure_dir(path):
    """Create ``path`` including missing parents; an existing directory is fine."""
    try:
        os.makedirs(path)
    except OSError:
        # Only "already there" is acceptable; surface every other failure.
        if not os.path.isdir(path):
            raise


def validate(val_loader, model, epoch, criterion_ADC, criterion_T2, split_ixs, Center_Crop, workers, seed, folder_name,
             test=False):
    """Evaluate ``model`` once per entry of ``split_ixs`` and save the results.

    For every patient id in ``split_ixs`` one batch is pulled from a
    ``MultiThreadedAugmenter`` built around ``val_loader`` and
    ``Center_Crop``.  The averaged loss of ``criterion_ADC`` (on ``'seg'``)
    and ``criterion_T2`` (on ``'seg_T2'``) is recorded, and the input
    channels, the binarised label maps, the arg-max class masks and the
    per-class probability maps are written as ``.nrrd`` files.

    Output directory:
        * ``test`` false: ``<folder_name>/Val_Images/Epoch_<epoch>/Patient_<id>``
        * ``test`` true:  ``<folder_name>/Test_Images/Patient_<id>``

    Inputs and targets are moved with ``.cuda()``, so a CUDA device is
    required.  The model is switched to evaluation mode.

    Returns:
        The running average (``AverageMeter.avg``) of the per-patient losses.
    """
    val_losses = AverageMeter()
    seeds = np.random.choice(seed, workers, False, None)
    torch.cuda.empty_cache()
    model.eval()
    multithreaded_generator = MultiThreadedAugmenter(val_loader, Center_Crop, workers, 2, seeds=seeds)

    loss_name = 'test_loss' if test else 'val_loss'

    for patient in split_ixs:
        print('patient {}'.format(patient))
        tensors = ToTensor(multithreaded_generator.next())
        target = tensors['seg'].cuda()
        target_T2 = tensors['seg_T2'].cuda()

        # `volatile=True` is ignored since PyTorch 0.4; no_grad() is what
        # actually keeps autograd from recording the forward pass.
        with torch.no_grad():
            input_var = tensors['data'].cuda(non_blocking=True).float()
            output = model(input_var)
            probs = F.softmax(output, dim=1)
            loss_ADC = criterion_ADC(output, target.long())
            loss_T2 = criterion_T2(output, target_T2.long())
            loss = (loss_ADC + loss_T2) / 2.

        val_losses.update(loss.item())
        print('{} {}'.format(loss_name, loss.item()))

        image = input_var.cpu().numpy()
        class_probs = probs.cpu().numpy()
        segmentation = target.cpu().numpy()
        segmentation_T2 = target_T2.cpu().numpy()

        # Hard decision per pixel; np.argmax breaks ties towards the lowest
        # class index.
        winner = np.argmax(class_probs, axis=1)

        if test:
            save_images_to = folder_name + '/Test_Images/Patient_{}'.format(patient)
        else:
            save_images_to = folder_name + '/Val_Images/Epoch_{}/Patient_{}'.format(epoch, patient)
        _ensure_dir(save_images_to)

        # (array, file name, is_mask) for everything written per patient.
        outputs = [
            (np.uint8(winner == 2), 'Tumor_Output.nrrd', True),
            (np.uint8(winner == 1), 'Prostate_Output.nrrd', True),
            (image[:, 0, :, :], 'ADCImage.nrrd', False),
            (image[:, 1, :, :], 'BVALImage.nrrd', False),
            (image[:, 2, :, :], 'T2Image.nrrd', False),
            (np.uint8(segmentation == 2), 'Label.nrrd', True),
            (np.uint8(segmentation_T2 == 2), 'Label_T2.nrrd', True),
            (np.uint8(segmentation == 1), 'Pro_Label.nrrd', True),
            (np.uint8(segmentation_T2 == 1), 'Pro_Label_T2.nrrd', True),
            (class_probs[:, 0, :, :], 'ProbabilityMapBack.nrrd', False),
            (class_probs[:, 2, :, :], 'ProbabilityMapTU.nrrd', False),
            (class_probs[:, 1, :, :], 'ProbabilityMapPRO.nrrd', False),
        ]
        for array, file_name, is_mask in outputs:
            save(sitk.GetImageFromArray(array), save_images_to + '/' + file_name, Mask=is_mask)

        torch.cuda.empty_cache()

    return val_losses.avg
+
+
def clear_image_data(folder_name, best_epoch, epoch):
    """Delete the validation image folders of every epoch except the best one.

    Walks epochs ``0 .. epoch`` (inclusive) and removes
    ``<folder_name>/Val_Images/Epoch_<e>/`` for each ``e != best_epoch``.
    Folders that cannot be removed (for example because they do not exist)
    are skipped silently.
    """
    last_epoch = epoch + 1
    for candidate in range(last_epoch):
        if candidate == best_epoch:
            print('best epoch')
            continue

        stale_dir = folder_name + '/Val_Images/Epoch_{}/'.format(candidate)
        try:
            shutil.rmtree(stale_dir)
        except OSError:
            # Best-effort cleanup: nothing to do if removal fails.
            pass
+
+
def save_checkpoint(state, filename):
    """Serialise ``state`` to ``filename`` with ``torch.save``."""
    torch.save(obj=state, f=filename)
+
+
class AverageMeter(object):
    """Keep the latest value and the running (weighted) average of a series.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: the weighted sum of all values seen since the last ``reset``.
        count: the total weight seen since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # float() forces true division on Python 2 as well; the guard avoids a
        # ZeroDivisionError when only zero-weight updates were recorded.
        self.avg = self.sum / float(self.count) if self.count else 0
+
+
def adjust_learning_rate(optimizer, arg_lr, epoch):
    """Set an exponentially decayed learning rate on every parameter group.

    The rate is ``arg_lr * 0.98 ** epoch`` converted to ``np.float32``.

    Returns:
        ``(optimizer, lr)`` - the same optimizer object and the rate applied.
    """
    decayed = arg_lr * 0.98 ** epoch
    lr = np.float32(decayed)

    for group in optimizer.param_groups:
        group['lr'] = lr

    return optimizer, lr
+
+
def resample(fixed, target_resample_resolution, Mask=False):
    """Resample the SimpleITK image ``fixed`` to a new voxel spacing.

    The output keeps the origin and direction of ``fixed``.  Its size is the
    input size scaled by ``spacing / target_resample_resolution`` and rounded
    to the nearest integer.  Masks (``Mask=True``) are interpolated with
    nearest neighbour, everything else with B-spline interpolation.

    Returns:
        The resampled image produced by ``sitk.ResampleImageFilter``.
    """
    interpolator = sitk.sitkNearestNeighbor if Mask == True else sitk.sitkBSpline

    spacing = np.array(fixed.GetSpacing())
    size = np.array(fixed.GetSize())

    # Keep the physical extent: voxel count grows as the spacing shrinks.
    scaled_size = size * (spacing / target_resample_resolution)
    new_size = np.around(scaled_size).astype(np.uint32).tolist()

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(target_resample_resolution)
    resampler.SetInterpolator(interpolator)
    resampler.SetOutputOrigin(fixed.GetOrigin())
    resampler.SetOutputDirection(fixed.GetDirection())
    resampler.SetSize(new_size)

    return resampler.Execute(fixed)
+
+
class BatchGenerator(DataLoaderBase):
    """Batch source producing normalised 2-D slices from per-patient volumes.

    ``data`` is indexed by patient id; every entry provides ``'image'`` and
    ``'label'`` arrays whose fourth axis is used as the slice axis.

    * Training mode (``test`` false): ``BATCH_SIZE`` patients are drawn from
      ``split_idx`` without replacement and one slice is sampled per patient.
      When ``ProbabilityTumorSlices`` is given, that much probability mass is
      spread over the slices containing label value 2 and the remainder over
      all other slices; otherwise slices are drawn uniformly.
    * Test mode (``test`` true): each call returns all slices of the next
      patient in ``split_idx`` (slice axis moved to the front) and sets
      ``BATCH_SIZE`` to that patient's slice count.  Once every patient has
      been served, zero-filled placeholder batches are returned.

    In both modes the image is passed through ``np.nan_to_num`` and
    ``cut_off_outliers`` and channels 0/1/2 are standardised with the
    ADC/BVAL/T2 mean and standard deviation given to the constructor.
    """

    def __init__(self, data, BATCH_SIZE, split_idx, seed, ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std,
                 ProbabilityTumorSlices=None, epoch=None, test=False):
        # Name the class explicitly: super(self.__class__, ...) recurses
        # forever as soon as this class is subclassed.
        super(BatchGenerator, self).__init__(data=data, BATCH_SIZE=BATCH_SIZE, seed=False, num_batches=None)
        self._split_idx = split_idx
        self._ProbabilityTumorSlices = ProbabilityTumorSlices
        self._epoch = epoch
        self._s = 0
        self._count = 0  # number of patients already served in test mode
        self.test = test
        self.seed = seed
        self.ADC_mean = ADC_mean
        self.BVAL_mean = BVAL_mean
        self.T2_mean = T2_mean
        self.ADC_std = ADC_std
        self.BVAL_std = BVAL_std
        self.T2_std = T2_std

    @staticmethod
    def _slice_probabilities(label_volume, z_dim, tumor_slice_probability):
        """Return the sampling probability of each of the ``z_dim`` slices."""
        uniform = np.ones(z_dim) / float(max(z_dim, 1))
        if tumor_slice_probability is None:
            return uniform

        has_tumor = np.array(
            [bool((label_volume[:, :, :, s] == 2).any()) for s in range(z_dim)], dtype=bool)
        n_tumor = int(has_tumor.sum())
        # Count the slices instead of summing their indices (a lone tumour in
        # slice 0 sums to 0).  With no tumour slice, or nothing but tumour
        # slices, there is only one group to sample from.
        if n_tumor == 0 or n_tumor == z_dim:
            return uniform

        tumor_mass = float(tumor_slice_probability)
        probabilities = np.empty(z_dim)
        probabilities[has_tumor] = tumor_mass / n_tumor
        probabilities[~has_tumor] = (1.0 - tumor_mass) / (z_dim - n_tumor)
        return probabilities

    def generate_train_batch(self):
        """Return the next batch as ``{"data": float32 image, "seg": label}``."""
        channels_img = 3
        channels_label = 2
        img_size = 200

        if self.test:
            if self._count < len(self._split_idx):
                volume = self._data[self._split_idx[self._count]]
                # One patient per batch: the slice axis becomes the batch axis.
                self.BATCH_SIZE = volume['image'].shape[3]
                img = volume['image'].transpose((3, 0, 1, 2))
                label = volume['label'].transpose((3, 0, 1, 2))
                self._count += 1
            else:
                # Every patient was served already; hand out a defined
                # placeholder instead of uninitialised memory.
                img = np.zeros((self.BATCH_SIZE, channels_img, img_size, img_size))
                label = np.zeros((self.BATCH_SIZE, channels_label, img_size, img_size))
        else:
            idx = np.random.choice(self._split_idx, self.BATCH_SIZE, False, None)
            img = np.empty((self.BATCH_SIZE, channels_img, img_size, img_size))
            label = np.empty((self.BATCH_SIZE, channels_label, img_size, img_size))

            for b in range(self.BATCH_SIZE):
                volume = self._data[idx[b]]
                # Needed on both sampling paths, so determine it up front.
                z_dim = volume['image'].shape[3]
                probabilities = self._slice_probabilities(
                    volume['label'], z_dim, self._ProbabilityTumorSlices)
                chosen_slice = np.random.choice(z_dim, p=probabilities)

                img[b, :, :, :] = volume['image'][:, :, :, chosen_slice]
                label[b, :, :, :] = volume['label'][:, :, :, chosen_slice]

        # cut off outliers before image normalization
        img = np.nan_to_num(img)
        img = cut_off_outliers(img, percentile_lower=0.2, percentile_upper=99.8, per_channel=True)

        # Plain float(): the np.float alias was removed from NumPy.
        img[:, 0, :, :] = (img[:, 0, :, :] - float(self.ADC_mean)) / float(self.ADC_std)
        img[:, 1, :, :] = (img[:, 1, :, :] - float(self.BVAL_mean)) / float(self.BVAL_std)
        img[:, 2, :, :] = (img[:, 2, :, :] - float(self.T2_mean)) / float(self.T2_std)

        img = np.float32(img)

        return {"data": img, "seg": label}
+
+
+
+
def CreateTrainValTestSplit(HistoFile_path, num_splits, num_val_folds, num_test_folds, seed):
    """Randomly divide the subjects of a CSV file into train/val/test lists.

    The unique values of the ``Master_ID`` column are distributed over
    ``num_splits`` folds: each fold first receives
    ``num_subjects // num_splits`` subjects, the remaining subjects are put
    into random folds.  The last ``num_test_folds`` folds form the test set,
    the ``num_val_folds`` folds before them the validation set and all
    earlier folds the training set.

    Reseeds NumPy's global RNG with ``seed``, so the result is reproducible.

    Returns:
        ``(train_ixs, val_ixs, test_ixs)`` - three flat lists of subject ids.
    """
    np.random.seed(seed)

    histo_table = pd.read_csv(HistoFile_path)
    unique_IDs = np.unique(histo_table.Master_ID.values)  # Patient no.
    num_subjects = len(unique_IDs)

    # -1 marks "not assigned to a fold yet".
    fold_of_subject = -1 * np.ones(num_subjects)
    s_per_split = num_subjects // num_splits
    assign = np.random.choice(range(num_subjects), size=(num_splits, s_per_split), replace=False)
    for fold, members in enumerate(assign):
        fold_of_subject[members] = fold

    # Subjects left over by the integer division go to random folds.
    leftover = np.where(fold_of_subject == -1)
    fold_of_subject[leftover] = np.random.randint(0, high=num_splits, size=len(leftover[0]))

    subjects = [unique_IDs[np.where(fold_of_subject == fold)] for fold in range(num_splits)]

    train_folds = list(range(num_splits - num_val_folds - num_test_folds))
    val_folds = [fold for fold in range(num_splits - num_test_folds) if fold not in train_folds]
    taken = val_folds + train_folds
    test_folds = [fold for fold in range(num_splits) if fold not in taken]

    def gather(folds):
        # Flatten the per-fold id arrays into one list.
        collected = []
        for fold in folds:
            collected.extend(subjects[fold])
        return collected

    return gather(train_folds), gather(val_folds), gather(test_folds)
+
+
def get_class_frequencies(Data, train_idx, patch_size):
    """Count label values and average intensity statistics over ``train_idx``.

    For every patient id in ``train_idx`` the ``'label'`` and ``'image'``
    arrays of ``Data`` are passed to ``center_crop`` with
    ``(patch_size[0], patch_size[1])`` and the first element of its result is
    used.  The cropped image additionally goes through ``np.nan_to_num`` and
    ``cut_off_outliers``.

    * Label channel 0 (ADC) and channel 1 (T2): the number of voxels equal to
      2 (tumor), 1 (prostate) and 0 (background) is summed over all patients.
    * Image channels 0 (ADC), 1 (BVAL) and 2 (T2): the per-patient mean and
      standard deviation are averaged over the patients.

    Returns:
        ``(Tumor_ADC, Prostate_ADC, Background_ADC, Tumor_T2, Prostate_T2,
        Background_T2, ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean,
        T2_std)``.

    Raises:
        ZeroDivisionError: if ``train_idx`` is empty.
    """
    Tumor_frequencie_ADC = 0
    Prostate_frequencie_ADC = 0
    Background_frequencie_ADC = 0

    Tumor_frequencie_T2 = 0
    Prostate_frequencie_T2 = 0
    Background_frequencie_T2 = 0

    T2_mean = 0
    T2_std = 0
    ADC_mean = 0
    ADC_std = 0
    BVAL_mean = 0
    BVAL_std = 0

    # Identical for every patient, so build it once.
    center_crop_dimensions = (patch_size[0], patch_size[1])

    for idx in train_idx:
        labels = center_crop(Data[idx]['label'], center_crop_dimensions)[0]
        image = center_crop(Data[idx]['image'], center_crop_dimensions)[0]

        image = np.nan_to_num(image)
        image = cut_off_outliers(image, percentile_lower=0.2, percentile_upper=99.8, per_channel=True)

        adc_image, bval_image, t2_image = image[0, :, :, :], image[1, :, :, :], image[2, :, :, :]
        T2_mean += np.mean(t2_image)
        T2_std += np.std(t2_image)
        ADC_mean += np.mean(adc_image)
        ADC_std += np.std(adc_image)
        BVAL_mean += np.mean(bval_image)
        BVAL_std += np.std(bval_image)

        adc_labels = labels[0, :, :, :]
        Tumor_frequencie_ADC += np.sum(adc_labels == 2)
        Prostate_frequencie_ADC += np.sum(adc_labels == 1)
        Background_frequencie_ADC += np.sum(adc_labels == 0)

        t2_labels = labels[1, :, :, :]
        Tumor_frequencie_T2 += np.sum(t2_labels == 2)
        Prostate_frequencie_T2 += np.sum(t2_labels == 1)
        Background_frequencie_T2 += np.sum(t2_labels == 0)

    # Plain float(): the np.float alias was removed from NumPy.
    n_patients = float(len(train_idx))
    T2_mean = T2_mean / n_patients
    T2_std = T2_std / n_patients
    ADC_mean = ADC_mean / n_patients
    ADC_std = ADC_std / n_patients
    BVAL_mean = BVAL_mean / n_patients
    BVAL_std = BVAL_std / n_patients

    return Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC, \
        Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2, ADC_mean, ADC_std, BVAL_mean, \
        BVAL_std, T2_mean, T2_std
+
+
def get_oversampling(Data, train_idx, Batch_Size, patch_size):
    """Derive slice statistics and an oversampling factor for ``train_idx``.

    Every patient's ``'label'`` array is passed to ``center_crop`` with
    ``(patch_size[0], patch_size[1])`` and the first element of the result is
    inspected.  In label channel 0, slices containing value 2 are counted as
    tumor slices and slices containing value 1 as prostate slices.  A patient
    counts as positive when the cropped label array contains value 2
    anywhere.  All counters and derived probabilities are printed.

    Returns:
        ``(Oversampling_Factor, Slices_total, Natural_probability_tu_slice,
        Natural_probability_Prostate_slice)`` where the factor is
        ``Natural_probability_tu_slice ** (1 / (Batch_Size * P(positive patient)))``.

    Raises:
        ZeroDivisionError: if there are no patients, no slices, or no
            positive patient.
    """
    IDs = sorted(train_idx)
    Total_Patients = len(IDs)

    # Single-argument print() gives the same text on Python 2 and 3.
    print('Total_Patients {}'.format(Total_Patients))

    Tumor_slices = 0
    Prostate_slices = 0
    Pos_Patient = 0
    Non_Tumor_Slices = 0
    Non_Prostate_Slices = 0

    center_crop_dimensions = (patch_size[0], patch_size[1])

    for idx in IDs:
        print(idx)
        label_volume = Data[idx]['label']
        n_slices = label_volume.shape[3]
        cropped = center_crop(label_volume, center_crop_dimensions)[0]

        for f in range(n_slices):
            label_slice = cropped[0, :, :, f]

            if (label_slice == 2).any():
                Tumor_slices += 1
            else:
                Non_Tumor_Slices += 1

            if (label_slice == 1).any():
                Prostate_slices += 1
            else:
                Non_Prostate_Slices += 1

        if (cropped == 2).any():
            Pos_Patient += 1

    print('Tumor {}'.format(Tumor_slices))
    print('Non_Tumor {}'.format(Non_Tumor_Slices))
    print('Pos_Patients {}'.format(Pos_Patient))

    print('Prostate {}'.format(Prostate_slices))
    print('Non_Prostate_Slices {}'.format(Non_Prostate_Slices))

    print('({}, {})'.format(Tumor_slices, Non_Tumor_Slices))

    # Plain float(): the np.float alias was removed from NumPy.
    Probability_Pos_Patient = Pos_Patient / float(Total_Patients)
    print('Probability_Pos_Patient {}'.format(Probability_Pos_Patient))

    Slices_total = Tumor_slices + Non_Tumor_Slices
    Natural_probability_tu_slice = Tumor_slices / float(Slices_total)
    Natural_probability_Prostate_slice = Prostate_slices / float(Slices_total)

    print('Natural_probability_tu_slice {}'.format(Natural_probability_tu_slice))
    print('Natural_probability_PRO_slice {}'.format(Natural_probability_Prostate_slice))

    Oversampling_Factor = Natural_probability_tu_slice ** (1 / float(Batch_Size * Probability_Pos_Patient))
    print('oversampling_factor {}'.format(Oversampling_Factor))

    return Oversampling_Factor, Slices_total, Natural_probability_tu_slice, Natural_probability_Prostate_slice
+
+
+
+
def split_training(IDs, len_val, cv, cv_runs):
    """Return the train/validation ids of cross-validation fold ``cv``.

    Folds are grouped into runs of four.  All folds of run ``cv // 4`` share
    one shuffle of ``IDs`` (NumPy seed ``42 + run``); fold ``cv % 4`` takes
    the ``len_val`` shuffled ids starting at ``len_val * (cv % 4)`` as
    validation set, and the fourth fold of a run additionally takes every id
    up to the end.  The training set is all of ``IDs`` not in the validation
    set (``np.setdiff1d``, i.e. sorted and unique).

    Side effects: reseeds NumPy's global RNG and prints the run index, the
    fold index within the run and the slice limits.

    Returns:
        ``(train_idx, val_idx)`` as NumPy arrays.

    Raises:
        ValueError: if ``cv`` is outside ``0 .. 4 * (cv_runs // 4) - 1``.
    """
    # Floor division: a float here would make the range check meaningless on
    # Python 3, where `/` is true division.
    runs = cv_runs // 4
    count = cv // 4
    if not 0 <= count < runs:
        raise ValueError(
            'cv={} is outside the {} folds covered by cv_runs={}'.format(cv, runs * 4, cv_runs))
    print(count)

    np.random.seed(42 + count)

    IDx = np.array(IDs)
    np.random.shuffle(IDx)

    factor = cv - count * 4
    print(factor)

    lower_limit = len_val * factor
    print(lower_limit)

    upper_limit = len_val * (factor + 1)
    if factor == 3:
        # The last fold of a run absorbs the remainder.
        upper_limit = len(IDx)
    print(upper_limit)

    val_idx = IDx[lower_limit:upper_limit]
    train_idx = np.setdiff1d(IDs, val_idx)

    return train_idx, val_idx
+
+
def save(Image, OutputFilePath, Mask = False):
    """Write ``Image`` to ``OutputFilePath`` and re-encode the file with gzip.

    The image is cast to ``sitkUInt8`` when ``Mask`` is true and to
    ``sitkFloat32`` otherwise, then written with SimpleITK.  The file is read
    back, axes 0 and 2 of the voxel array are swapped, and the array is
    written again through ``nrrd.write`` using the header read from the file
    with its ``'encoding'`` entry set to ``'gzip'``.
    """
    pixel_type = sitk.sitkUInt8 if Mask == True else sitk.sitkFloat32
    sitk.WriteImage(sitk.Cast(Image, pixel_type), OutputFilePath)

    # Array obtained through SimpleITK, with its first and last axes exchanged
    # before it is handed to nrrd.write.
    voxels = sitk.GetArrayFromImage(sitk.ReadImage(OutputFilePath)).swapaxes(0, 2)

    # Only the header of the nrrd read is needed.
    _, header = nrrd.read(OutputFilePath)
    header['encoding'] = 'gzip'
    nrrd.write(OutputFilePath, voxels, header=header)
\ No newline at end of file