diff --git a/brats_script.py b/brats_script.py new file mode 100644 index 0000000..0ab2cd1 --- /dev/null +++ b/brats_script.py @@ -0,0 +1,411 @@ +import matplotlib + +matplotlib.use("Agg", warn=True) + +import json +import math +import os + +import numpy as np +import torch +import torch.utils.data +from torch import nn, optim +import torch.distributions as dist +from torch.optim.lr_scheduler import StepLR + +from trixi.logger.experiment.pytorchexperimentlogger import PytorchExperimentLogger +from trixi.logger import PytorchVisdomLogger +from trixi.util import Config +from trixi.util.pytorchutils import set_seed + +from models.enc_dec import Encoder, Generator +from data.brain_ds import BrainDataSet +from utils.util import smooth_tensor, normalize, find_best_val, calc_hard_dice + + +class VAE(torch.nn.Module): + def __init__(self, input_size, h_size, z_dim, to_1x1=True, conv_op=torch.nn.Conv2d, + upsample_op=torch.nn.ConvTranspose2d, normalization_op=None, activation_op=torch.nn.LeakyReLU, + conv_params=None, activation_params=None, block_op=None, block_params=None, output_channels=None, + additional_input_slices=None, + *args, **kwargs): + + super(VAE, self).__init__() + + input_size_enc = list(input_size) + input_size_dec = list(input_size) + if output_channels is not None: + input_size_dec[0] = output_channels + if additional_input_slices is not None: + input_size_enc[0] += additional_input_slices * 2 + + self.encoder = Encoder(image_size=input_size_enc, h_size=h_size, z_dim=z_dim * 2, + normalization_op=normalization_op, to_1x1=to_1x1, conv_op=conv_op, + conv_params=conv_params, + activation_op=activation_op, activation_params=activation_params, block_op=block_op, + block_params=block_params) + self.decoder = Generator(image_size=input_size_dec, h_size=h_size[::-1], z_dim=z_dim, + normalization_op=normalization_op, to_1x1=to_1x1, upsample_op=upsample_op, + conv_params=conv_params, activation_op=activation_op, + activation_params=activation_params, block_op=block_op, + block_params=block_params) + + self.hidden_size = self.encoder.output_size + + def forward(self, inpt, sample=None, **kwargs): + enc = self.encoder(inpt, **kwargs) + + mu, log_std = torch.chunk(enc, 2, dim=1) + std = torch.exp(log_std) + z_dist = dist.Normal(mu, std) + + if sample or self.training: + z = z_dist.rsample() + else: + z = mu + + x_rec = self.decoder(z, **kwargs) + + return x_rec, mu, std + + def encode(self, inpt, **kwargs): + enc = self.encoder(inpt, **kwargs) + mu, log_std = torch.chunk(enc, 2, dim=1) + return mu, log_std + + def decode(self, inpt, **kwargs): + x_rec = self.decoder(inpt, **kwargs) + return x_rec + + +def loss_function(recon_x, x, mu, logstd, rec_log_std=0, sum_samplewise=True): + rec_std = math.exp(rec_log_std) + rec_var = rec_std ** 2 + + x_dist = dist.Normal(recon_x, rec_std) + log_p_x_z = x_dist.log_prob(x) + if sum_samplewise: + log_p_x_z = torch.sum(log_p_x_z, dim=(1, 2, 3)) + + z_prior = dist.Normal(0, 1.) 
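+    # dist.kl_divergence below uses the closed form for diagonal Gaussians,
+    # KL(N(mu, std) || N(0, 1)) = -log(std) + (std ** 2 + mu ** 2 - 1) / 2,
+    # evaluated per latent dimension (no sampling is involved in the KL term)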
+ z_post = dist.Normal(mu, torch.exp(logstd)) + + kl_div = dist.kl_divergence(z_post, z_prior) + if sum_samplewise: + kl_div = torch.sum(kl_div, dim=(1, 2, 3)) + + if sum_samplewise: + loss = torch.mean(kl_div - log_p_x_z) + else: + loss = torch.mean(torch.sum(kl_div, dim=(1, 2, 3)) - torch.sum(log_p_x_z, dim=(1, 2, 3))) + + return loss, kl_div, -log_p_x_z + + +def get_inpt_grad(model, inpt, err_fn): + model.zero_grad() + inpt = inpt.detach() + inpt.requires_grad = True + + err = err_fn(inpt) + err.backward() + + grad = inpt.grad.detach() + + model.zero_grad() + + return torch.abs(grad.detach()) + + +def train(epoch, model, optimizer, train_loader, device, vlog, elog, log_var_std): + model.train() + train_loss = 0 + for batch_idx, data in enumerate(train_loader): + data = data["data"][0].float().to(device) + optimizer.zero_grad() + recon_batch, mu, logstd = model(data) + loss, kl, rec = loss_function(recon_batch, data, mu, logstd, log_var_std) + loss.backward() + train_loss += loss.item() + optimizer.step() + if batch_idx % 100 == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx, len(train_loader), + 100. * batch_idx / len(train_loader), + loss.item() / len(data))) + vlog.show_value(torch.mean(kl).item(), name="Kl-loss", tag="Losses") + vlog.show_value(torch.mean(rec).item(), name="Rec-loss", tag="Losses") + vlog.show_value(loss.item(), name="Total-loss", tag="Losses") + + print('====> Epoch: {} Average loss: {:.4f}'.format( + epoch, train_loss / len(train_loader))) + + +def test_slice(model, test_loader, test_loader_abnorm, device, vlog, elog, image_size, batch_size, log_var_std): + model.eval() + test_loss = [] + kl_loss = [] + rec_loss = [] + with torch.no_grad(): + for i, data in enumerate(test_loader): + data = data["data"][0].float().to(device) + recon_batch, mu, logstd = model(data) + loss, kl, rec = loss_function(recon_batch, data, mu, logstd, log_var_std) + test_loss += (kl + rec).tolist() + kl_loss += kl.tolist() + rec_loss += rec.tolist() + # if i == 0: + # n = min(data.size(0), 8) + # comparison = torch.cat([data[:n], + # recon_batch[:n]]) + # vlog.show_image_grid(comparison.cpu(), name='reconstruction') + + vlog.show_value(np.mean(kl_loss), name="Norm-Kl-loss", tag="Anno") + vlog.show_value(np.mean(rec_loss), name="Norm-Rec-loss", tag="Anno") + vlog.show_value(np.mean(test_loss), name="Norm-Total-loss", tag="Anno") + elog.show_value(np.mean(kl_loss), name="Norm-Kl-loss", tag="Anno") + elog.show_value(np.mean(rec_loss), name="Norm-Rec-loss", tag="Anno") + elog.show_value(np.mean(test_loss), name="Norm-Total-loss", tag="Anno") + + test_loss_ab = [] + kl_loss_ab = [] + rec_loss_ab = [] + with torch.no_grad(): + for i, data in enumerate(test_loader_abnorm): + data = data["data"][0].float().to(device) + recon_batch, mu, logstd = model(data) + loss, kl, rec = loss_function(recon_batch, data, mu, logstd, log_var_std) + test_loss_ab += (kl + rec).tolist() + kl_loss_ab += kl.tolist() + rec_loss_ab += rec.tolist() + # if i == 0: + # n = min(data.size(0), 8) + # comparison = torch.cat([data[:n], + # recon_batch[:n]]) + # vlog.show_image_grid(comparison.cpu(), name='reconstruction2') + + elog.print('====> Test set loss: {:.4f}'.format(np.mean(test_loss))) + + vlog.show_value(np.mean(kl_loss_ab), name="Unorm-Kl-loss", tag="Anno") + vlog.show_value(np.mean(rec_loss_ab), name="Unorm-Rec-loss", tag="Anno") + vlog.show_value(np.mean(test_loss_ab), name="Unorm-Total-loss", tag="Anno") + elog.show_value(np.mean(kl_loss_ab), name="Unorm-Kl-loss", tag="Anno") 
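+    # Slice-level evaluation: the per-slice KL, reconstruction and combined ELBO
+    # terms act as anomaly scores; slices from the normal loader are labeled 0 and
+    # lesion slices 1, so the ROC/PR values computed below measure slice-wise detection.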
+ elog.show_value(np.mean(rec_loss_ab), name="Unorm-Rec-loss", tag="Anno") + elog.show_value(np.mean(test_loss_ab), name="Unorm-Total-loss", tag="Anno") + + kl_roc, kl_pr = elog.get_classification_metrics(kl_loss + kl_loss_ab, + [0] * len(kl_loss) + [1] * len(kl_loss_ab), + )[0] + rec_roc, rec_pr = elog.get_classification_metrics(rec_loss + rec_loss_ab, + [0] * len(rec_loss) + [1] * len(rec_loss_ab), + )[0] + loss_roc, loss_pr = elog.get_classification_metrics(test_loss + test_loss_ab, + [0] * len(test_loss) + [1] * len(test_loss_ab), + )[0] + + vlog.show_value(np.mean(kl_roc), name="KL-loss", tag="ROC") + vlog.show_value(np.mean(rec_roc), name="Rec-loss", tag="ROC") + vlog.show_value(np.mean(loss_roc), name="Total-loss", tag="ROC") + elog.show_value(np.mean(kl_roc), name="KL-loss", tag="ROC") + elog.show_value(np.mean(rec_roc), name="Rec-loss", tag="ROC") + elog.show_value(np.mean(loss_roc), name="Total-loss", tag="ROC") + + vlog.show_value(np.mean(kl_pr), name="KL-loss", tag="PR") + vlog.show_value(np.mean(rec_pr), name="Rec-loss", tag="PR") + vlog.show_value(np.mean(loss_pr), name="Total-loss", tag="PR") + + return kl_roc, rec_roc, loss_roc, kl_pr, rec_pr, loss_pr, np.mean(test_loss) + + +def test_pixel(model, test_loader_pixel, device, vlog, elog, image_size, batch_size, log_var_std): + model.eval() + + test_loss = [] + kl_loss = [] + rec_loss = [] + + pixel_class = [] + pixel_rec_err = [] + pixel_grad_all = [] + pixel_grad_kl = [] + pixel_grad_rec = [] + pixel_combi_err = [] + + with torch.no_grad(): + for i, data in enumerate(test_loader_pixel): + inpt = data["data"][0].float().to(device) + seg = data["seg"].float()[0, :, 0] + seg_flat = seg.flatten() > 0.5 + pixel_class += seg_flat.tolist() + + recon_batch, mu, logstd = model(inpt) + + loss, kl, rec = loss_function(recon_batch, inpt, mu, logstd, log_var_std, sum_samplewise=False) + rec = rec.detach().cpu() + pixel_rec_err += rec.flatten().tolist() + + def __err_fn_all(x, loss_idx=0): # loss_idx 0: elbo, 1: kl part, 2: rec part + outpt = model(x) + recon_batch, mu, logstd = outpt + loss = loss_function(recon_batch, x, mu, logstd, log_var_std) + return torch.mean(loss[loss_idx]) + + with torch.enable_grad(): + loss_grad_all = get_inpt_grad(model=model, inpt=inpt, err_fn=lambda x: __err_fn_all(x, 0), + ).detach().cpu() + loss_grad_kl = get_inpt_grad(model=model, inpt=inpt, err_fn=lambda x: __err_fn_all(x, 1), + ).detach().cpu() + loss_grad_rec = get_inpt_grad(model=model, inpt=inpt, err_fn=lambda x: __err_fn_all(x, 2), + ).detach().cpu() + + pixel_grad_all += smooth_tensor(loss_grad_all).flatten().tolist() + pixel_grad_kl += smooth_tensor(loss_grad_kl).flatten().tolist() + pixel_grad_rec += smooth_tensor(loss_grad_rec).flatten().tolist() + + pixel_combi_err += (smooth_tensor(normalize(loss_grad_kl)) * rec).flatten().tolist() + + kl_normalized = np.asarray(pixel_grad_kl) + kl_normalized = (kl_normalized - np.min(kl_normalized)) / (np.max(kl_normalized) - np.min(kl_normalized)) + rec_normalized = np.asarray(pixel_rec_err) + rec_normalized = (rec_normalized - np.min(rec_normalized)) / (np.max(rec_normalized) - np.min(rec_normalized)) + combi_add = kl_normalized + rec_normalized + + rec_err_roc, rec_err_pr = elog.get_classification_metrics(pixel_rec_err, pixel_class)[0] + grad_all_roc, grad_all_pr = elog.get_classification_metrics(pixel_grad_all, pixel_class)[0] + grad_kl_roc, grad_kl_pr = elog.get_classification_metrics(pixel_grad_kl, pixel_class)[0] + grad_rec_roc, grad_rec_pr = elog.get_classification_metrics(pixel_grad_rec, 
pixel_class)[0] + pixel_combi_roc, pixel_combi_pr = elog.get_classification_metrics(pixel_combi_err, pixel_class)[0] + add_combi_roc, add_combi_pr = elog.get_classification_metrics(combi_add, pixel_class)[0] + + rec_err_dice, reconst_thres = find_best_val(pixel_rec_err, pixel_class, calc_hard_dice, max_steps=8, + val_range=(0, np.max(pixel_rec_err))) + grad_kl_dice, grad_kl_thres = find_best_val(pixel_grad_kl, pixel_class, calc_hard_dice, max_steps=8, + val_range=(0, np.max(pixel_grad_kl))) + pixel_combi_dice, pixel_combi_thres = find_best_val(pixel_combi_err, pixel_class, calc_hard_dice, max_steps=8, + val_range=(0, np.max(pixel_combi_err))) + add_combi_dice, _ = find_best_val(combi_add, pixel_class, calc_hard_dice, max_steps=8, + val_range=(0, np.max(combi_add))) + + with open(os.path.join(elog.result_dir, "pixel.json"), "a+") as file_: + json.dump({ + "rec_err_roc": rec_err_roc, "rec_err_pr": rec_err_pr, + "grad_all_roc": grad_all_roc, "grad_all_pr": grad_all_pr, + "grad_kl_roc": grad_kl_roc, "grad_kl_pr": grad_kl_pr, + "grad_rec_roc": grad_rec_roc, "grad_rec_pr": grad_rec_pr, + "pixel_combi_roc": pixel_combi_roc, "pixel_combi_pr": pixel_combi_pr, + "rec_err_dice": rec_err_dice, "grad_kl_dice": grad_kl_dice, "pixel_combi_dice": + pixel_combi_dice, + + }, file_, indent=4) + + +def model_run(patch_size, batch_size, odd_class, z, seed=123, log_var_std=0, n_epochs=5, + model_h_size=(16, 32, 64, 256), exp_name="exp", folder_name="exp"): + set_seed(seed) + + config = Config( + patch_size=patch_size, batch_size=batch_size, odd_class=odd_class, z=z, seed=seed, log_var_std=log_var_std, + n_epochs=n_epochs + ) + + device = torch.device("cuda") + + datasets_common_args = { + "batch_size": batch_size, + "target_size": patch_size, + "input_slice": [1, ], + "add_noise": True, + "mask_type": "gaussian", # 0.0, ## TODO + "elastic_deform": False, + "rnd_crop": True, + "rotate": True, + "color_augment": True, + "add_slices": 0, + } + + input_shape = ( + datasets_common_args["batch_size"], 1, datasets_common_args["target_size"], datasets_common_args["target_size"]) + + train_set_args = { + "base_dir": "hcp/", + # "num_batches": 500, + "slice_offset": 20, + "num_processes": 8, + } + test_set_normal_args = { + "base_dir": "brats17/", + # "num_batches": 100, + "do_reshuffle": False, + "mode": "val", + "num_processes": 2, + "slice_offset": 20, + "label_slice": 2, + "only_labeled_slices": False, + } + test_set_unormal_args = { + "base_dir": "brats17/", + # "num_batches": 100, + "do_reshuffle": False, + "mode": "val", + "num_processes": 2, + "slice_offset": 20, + "label_slice": 2, + "only_labeled_slices": True, + "labeled_threshold": 10, + } + test_set_all_args = { + "base_dir": "brats17_test/", + # "num_batches": 50, + "do_reshuffle": False, + "mode": "val", + "num_processes": 2, + "slice_offset": 20, + "label_slice": 2, + } + + train_loader = BrainDataSet(**datasets_common_args, **train_set_args) + test_loader_normal = BrainDataSet(**datasets_common_args, **test_set_normal_args) + test_loader_abnorm = BrainDataSet(**datasets_common_args, **test_set_unormal_args) + test_loader_all = BrainDataSet(**datasets_common_args, **test_set_all_args) + + model = VAE(input_size=input_shape[1:], h_size=model_h_size, z_dim=z).to(device) + optimizer = optim.Adam(model.parameters(), lr=1e-4) + lr_scheduler = StepLR(optimizer, step_size=1) + + vlog = PytorchVisdomLogger(exp_name=exp_name) + elog = PytorchExperimentLogger(base_dir=folder_name, exp_name=exp_name) + + elog.save_config(config, "config") + + for epoch in range(1, 
n_epochs + 1): + train(epoch, model, optimizer, train_loader, device, vlog, elog, log_var_std) + + kl_roc, rec_roc, loss_roc, kl_pr, rec_pr, loss_pr, test_loss = test_slice(model, test_loader_normal, + test_loader_abnorm, device, + vlog, elog, input_shape, batch_size, + log_var_std) + + with open(os.path.join(elog.result_dir, "results.json"), "w") as file_: + json.dump({ + "kl_roc": kl_roc, "rec_roc": rec_roc, "loss_roc": loss_roc, + "kl_pr": kl_pr, "rec_pr": rec_pr, "loss_pr": loss_pr, + }, file_, indent=4) + + elog.save_model(model, "vae") + + test_pixel(model, test_loader_all, device, vlog, elog, input_shape, batch_size, log_var_std) + + print("All done....") + + +if __name__ == '__main__': + torch.backends.cudnn.benchmark = True + + patch_size = 64 + batch_size = 64 + odd_class = 0 + z = 256 + seed = 123 + log_var_std = 0. + + model_run(patch_size, batch_size, odd_class, z, seed, log_var_std) diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/__pycache__/__init__.cpython-36.pyc b/data/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..84b8226 Binary files /dev/null and b/data/__pycache__/__init__.cpython-36.pyc differ diff --git a/data/__pycache__/brain_ds.cpython-36.pyc b/data/__pycache__/brain_ds.cpython-36.pyc new file mode 100644 index 0000000..57f1240 Binary files /dev/null and b/data/__pycache__/brain_ds.cpython-36.pyc differ diff --git a/data/__pycache__/data_loader.cpython-36.pyc b/data/__pycache__/data_loader.cpython-36.pyc new file mode 100644 index 0000000..c259321 Binary files /dev/null and b/data/__pycache__/data_loader.cpython-36.pyc differ diff --git a/data/brain_ds.py b/data/brain_ds.py new file mode 100644 index 0000000..0ecece2 --- /dev/null +++ b/data/brain_ds.py @@ -0,0 +1,268 @@ +import fnmatch +import os +import random +import shutil +import string +from collections import defaultdict +from time import sleep + +import numpy as np + +from batchgenerators.dataloading.data_loader import DataLoaderBase, SlimDataLoaderBase +from batchgenerators.transforms import BrightnessMultiplicativeTransform, BrightnessTransform, GaussianNoiseTransform, \ + MirrorTransform, SpatialTransform +from batchgenerators.transforms.abstract_transforms import Compose, RndTransform +from batchgenerators.transforms.color_transforms import ClipValueRange +from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform, PadTransform, FillupPadTransform +from batchgenerators.transforms.noise_transforms import BlankSquareNoiseTransform, GaussianBlurTransform, \ + SquareMaskTransform +from batchgenerators.transforms.spatial_transforms import ResizeTransform, ZoomTransform +from batchgenerators.transforms.utility_transforms import AddToDictTransform, CopyTransform, NumpyToTensor, \ + ReshapeTransform + +from data.data_loader import MultiThreadedDataLoader + + +def load_dataset(base_dir, pattern='*.npy', slice_offset=0, only_labeled_slices=None, label_slice=None, + labeled_threshold=10): + fls = [] + files_len = [] + slices = [] + + for root, dirs, files in os.walk(base_dir): + for i, filename in enumerate(sorted(fnmatch.filter(files, pattern))): + npy_file = os.path.join(root, filename) + numpy_array = np.load(npy_file, mmap_mode="r") + + fls.append(npy_file) + files_len.append(numpy_array.shape[1]) + + if only_labeled_slices is None: + + slices.extend([(i, j) for j in range(slice_offset, files_len[-1] - slice_offset)]) + else: + assert label_slice is not None + + for s_idx in range(slice_offset, 
numpy_array.shape[1] - slice_offset): + + pixel_sum = np.sum(numpy_array[label_slice, s_idx] > 0.1) + if pixel_sum > labeled_threshold: + if only_labeled_slices is True: + slices.append((i, s_idx)) + elif pixel_sum == 0: + if only_labeled_slices is False: + slices.append((i, s_idx)) + + return fls, files_len, slices + + +def get_transforms(mode="train", n_channels=1, target_size=128, add_resize=False, add_noise=False, mask_type="", + batch_size=16, rotate=True, elastic_deform=True, rnd_crop=False, color_augment=True): + tranform_list = [] + noise_list = [] + + if mode == "train": + + tranform_list = [FillupPadTransform(min_size=(n_channels, target_size + 5, target_size + 5)), + ResizeTransform(target_size=(target_size + 1, target_size + 1), + order=1, concatenate_list=True), + + # RandomCropTransform(crop_size=(target_size + 5, target_size + 5)), + MirrorTransform(axes=(2,)), + ReshapeTransform(new_shape=(1, -1, "h", "w")), + SpatialTransform(patch_size=(target_size, target_size), random_crop=rnd_crop, + patch_center_dist_from_border=target_size // 2, + do_elastic_deform=elastic_deform, alpha=(0., 100.), sigma=(10., 13.), + do_rotation=rotate, + angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8), + scale=(0.9, 1.2), + border_mode_data="nearest", border_mode_seg="nearest"), + ReshapeTransform(new_shape=(batch_size, -1, "h", "w"))] + if color_augment: + tranform_list += [ # BrightnessTransform(mu=0, sigma=0.2), + BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1))] + + tranform_list += [ + GaussianNoiseTransform(noise_variance=(0., 0.05)), + ClipValueRange(min=-1.5, max=1.5), + ] + + noise_list = [] + if mask_type == "gaussian": + noise_list += [GaussianNoiseTransform(noise_variance=(0., 0.2))] + + + elif mode == "val": + tranform_list = [FillupPadTransform(min_size=(n_channels, target_size + 5, target_size + 5)), + ResizeTransform(target_size=(target_size + 1, target_size + 1), + order=1, concatenate_list=True), + CenterCropTransform(crop_size=(target_size, target_size)), + ClipValueRange(min=-1.5, max=1.5), + # BrightnessTransform(mu=0, sigma=0.2), + # BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1)), + CopyTransform({"data": "data_clean"}, copy=True) + ] + + + noise_list += [] + + if add_noise: + tranform_list = tranform_list + noise_list + + + tranform_list.append(NumpyToTensor()) + + return Compose(tranform_list) + + +class BrainDataSet(object): + def __init__(self, base_dir, mode="train", batch_size=16, num_batches=None, seed=None, + num_processes=8, num_cached_per_queue=8 * 4, target_size=128, file_pattern='*.npy', + rescale_data=False, add_noise=False, label_slice=None, input_slice=(0,), mask_type="", + slice_offset=0, do_reshuffle=True, only_labeled_slices=None, labeled_threshold=10, + rotate=True, elastic_deform=True, rnd_crop=True, color_augment=True, tmp_dir=None, use_npz=False, + add_slices=0): + data_loader = BrainDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, + num_batches=num_batches, seed=seed, file_pattern=file_pattern, + input_slice=input_slice, label_slice=label_slice, slice_offset=slice_offset, + only_labeled_slices=only_labeled_slices, labeled_threshold=labeled_threshold, + tmp_dir=tmp_dir, use_npz=use_npz, add_slices=add_slices) + + self.data_loader = data_loader + self.batch_size = batch_size + self.do_reshuffle = do_reshuffle + self.n_channels = (add_slices * 2 + 1) + self.transforms = get_transforms(mode=mode, target_size=target_size, add_resize=rescale_data, + add_noise=add_noise, mask_type=mask_type, 
batch_size=batch_size,
+                                         rotate=rotate, elastic_deform=elastic_deform, rnd_crop=rnd_crop,
+                                         color_augment=color_augment, n_channels=self.n_channels)
+        self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,
+                                                 num_cached_per_queue=num_cached_per_queue, seeds=seed,
+                                                 shuffle=do_reshuffle)
+        self.augmenter.restart()
+        self.first = True
+
+    def __len__(self):
+        return len(self.data_loader)
+
+    def __iter__(self):
+        if self.do_reshuffle:
+            self.data_loader.reshuffle()
+        self.augmenter.renew()
+        return iter(self.augmenter)
+
+    # def __next__(self):
+    #     return next(self.augmenter)
+
+    def __getitem__(self, index):
+        item = self.data_loader[index]
+        item = self.transforms(**item)
+        return item
+
+
+class BrainDataLoader(SlimDataLoaderBase):
+    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=None,
+                 seed=None, file_pattern='*.npy', label_slice=None, input_slice=(0,), slice_offset=0,
+                 only_labeled_slices=None, labeled_threshold=10, tmp_dir=None, use_npz=False, add_slices=0):
+
+        self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern,
+                                                              slice_offset=slice_offset + add_slices,
+                                                              only_labeled_slices=only_labeled_slices,
+                                                              label_slice=label_slice,
+                                                              labeled_threshold=labeled_threshold)
+        # deliberately skips SlimDataLoaderBase.__init__, which would require a data argument
+        super(SlimDataLoaderBase, self).__init__()
+
+        self.batch_size = batch_size
+        self.tmp_dir = tmp_dir
+        self.use_npz = use_npz
+        if self.tmp_dir is not None and self.tmp_dir != "" and self.tmp_dir != "None":
+            rnd_str = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(15))
+            self.tmp_dir = os.path.join(self.tmp_dir, rnd_str)
+            if not os.path.exists(self.tmp_dir):
+                os.mkdir(self.tmp_dir)
+
+        self.use_next = False
+        if mode == "train":
+            self.use_next = False
+
+        self.slice_idxs = list(range(0, len(self.slices)))
+        self.data_len = len(self.slices)
+        if num_batches is None:
+            self.n_items = self.data_len // self.batch_size
+            self.num_batches = self.data_len // self.batch_size
+        else:
+            self.num_batches = num_batches
+            self.n_items = min(self.data_len // self.batch_size, self.num_batches)
+
+        if isinstance(label_slice, int):
+            label_slice = (label_slice,)
+        self.input_slice = input_slice
+        self.label_slice = label_slice
+
+        self.add_slices = add_slices
+
+        # print(self.slice_idxs)
+
+        self.np_data = np.asarray(self.slices)
+
+    def reshuffle(self):
+        print("Reshuffle...")
+        random.shuffle(self.slice_idxs)
+
+    def generate_train_batch(self):
+        # sample random (file_idx, slice_idx) pairs; self.slices is what __init__
+        # fills (the base-class attribute self._data is never populated here)
+        open_arr = random.sample(self.slices, self.batch_size)
+        return self.get_data_from_array(open_arr)
+
+    def __len__(self):
+        n_items = min(self.data_len // self.batch_size, self.num_batches)
+        return n_items
+
+    def __getitem__(self, item):
+
+        if item >= self.n_items:
+            raise StopIteration()
+
+        start_idx = (item * self.batch_size) % self.data_len
+        stop_idx = ((item + 1) * self.batch_size) % self.data_len
+
+        if stop_idx > start_idx:
+            idxs = self.slice_idxs[start_idx:stop_idx]
+        else:
+            # stop_idx wrapped to 0, i.e. this is the final full batch
+            idxs = self.slice_idxs[:stop_idx] + self.slice_idxs[start_idx:]
+
+        open_arr = self.np_data[idxs]
+        return self.get_data_from_array(open_arr)
+
+    def get_data_from_array(self, open_array):
+
+        data = []
+        fnames = []
+        slice_idxs = []
+        labels = []
+
+        for slice_info in open_array:
+            fn_name = self.files[slice_info[0]]
+            slice_idx = slice_info[1]
+
+            numpy_array = np.load(fn_name, mmap_mode="r")
+            numpy_slice = numpy_array[self.input_slice, slice_idx - self.add_slices:slice_idx + self.add_slices + 1, ]
+
+            data.append(numpy_slice)
+
+            if self.label_slice is not None:
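+                # the segmentation channel is sliced exactly like the image data,
+                # so the "seg" entry stays aligned with "data" for pixel-wise scoring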
+ label_slice = numpy_array[self.label_slice, + slice_idx - self.add_slices:slice_idx + self.add_slices + 1, ] + labels.append(label_slice) + + fnames.append(fn_name) + slice_idxs.append(slice_idx / 200.) + del numpy_array + + ret_dict = {'data': data, 'fnames': fnames, 'slice_idxs': slice_idxs} + if self.label_slice is not None: + ret_dict['seg'] = labels + + return ret_dict + + diff --git a/data/data_loader.py b/data/data_loader.py new file mode 100644 index 0000000..c8cfc87 --- /dev/null +++ b/data/data_loader.py @@ -0,0 +1,68 @@ +import warnings +from torch.utils.data import DataLoader, Dataset + + +class WrappedDataset(Dataset): + def __init__(self, dataset, transforms): + self.transforms = transforms + self.dataset = dataset + + self.is_indexable = False + if hasattr(self.dataset, "__getitem__") and not (hasattr(self.dataset, "use_next") and self.dataset.use_next \ + is True): + self.is_indexable = True + + def __getitem__(self, index): + + if not self.is_indexable: + item = next(self.dataset) + else: + item = self.dataset[index] + + item = self.transforms(**item) + return item + + def __len__(self): + return len(self.dataset) + + +class MultiThreadedDataLoader(object): + def __init__(self, data_loader, transform, num_processes, shuffle=True, timeout=120, **kwargs): + self.transform = transform + self.timeout = timeout + + self.cntr = 1 + self.ds_wrapper = WrappedDataset(data_loader, transform) + + self.generator = DataLoader(self.ds_wrapper, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, + num_workers=num_processes, pin_memory=False, drop_last=False, timeout=0) + + self.num_processes = num_processes + self.iter = None + + def __iter__(self): + self.iter = iter(self.generator) + return self + + def del_iter(self): + del self.iter + + def __next__(self): + try: + return next(self.iter) + except RuntimeError: + print("Queue is empty, None returned") + warnings.warn("Queue is empty, None returned") + raise StopIteration + return + + def renew(self): + if self.cntr > 1: + self.generator.timeout = self.timeout + self.cntr += 1 + + def restart(self): + pass + + def kill_iterator(self): + pass diff --git a/mnist_script.py b/mnist_script.py new file mode 100644 index 0000000..b1f6fb6 --- /dev/null +++ b/mnist_script.py @@ -0,0 +1,245 @@ +import matplotlib + +matplotlib.use("Agg", warn=True) + +import json +import math +import os +from collections import defaultdict + +import numpy as np +import torch +import torch.utils.data +from torch import nn, optim +from torch.nn import functional as F +import torch.distributions as dist + +from torchvision import datasets, transforms +from torchvision.utils import save_image + +from trixi.logger.experiment.pytorchexperimentlogger import PytorchExperimentLogger +from trixi.logger import PytorchVisdomLogger +from trixi.util import Config +from trixi.util.pytorchutils import set_seed + + +class VAE(nn.Module): + def __init__(self, z=20, input_size=784): + super(VAE, self).__init__() + + self.fc1 = nn.Linear(input_size, 400) + self.fc21 = nn.Linear(400, z) + self.fc22 = nn.Linear(400, z) + self.fc3 = nn.Linear(z, 400) + self.fc4 = nn.Linear(400, input_size) + + def encode(self, x): + h1 = F.relu(self.fc1(x)) + return self.fc21(h1), self.fc22(h1) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def decode(self, z): + h3 = F.relu(self.fc3(z)) + return torch.sigmoid(self.fc4(h3)) + + def forward(self, x): + mu, logstd = self.encode(x) + z = 
self.reparameterize(mu, logstd) + return self.decode(z), mu, logstd + + +def loss_function(recon_x, x, mu, logstd, rec_log_std=0): + rec_std = math.exp(rec_log_std) + rec_var = rec_std ** 2 + + x_dist = dist.Normal(recon_x, rec_std) + log_p_x_z = torch.sum(x_dist.log_prob(x), dim=1) + + z_prior = dist.Normal(0, 1.) + z_post = dist.Normal(mu, torch.exp(logstd)) + kl_div = torch.sum(dist.kl_divergence(z_post, z_prior), dim=1) + + return torch.mean(kl_div - log_p_x_z), kl_div, -log_p_x_z + + +def train(epoch, model, optimizer, train_loader, device, scaling, vlog, elog, log_var_std): + model.train() + train_loss = 0 + for batch_idx, (data, _) in enumerate(train_loader): + data = data.to(device) + data_flat = data.flatten(start_dim=1).repeat(1, scaling) + optimizer.zero_grad() + recon_batch, mu, logvar = model(data_flat) + loss, kl, rec = loss_function(recon_batch, data_flat, mu, logvar, log_var_std) + loss.backward() + train_loss += loss.item() + optimizer.step() + if batch_idx % 10 == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), + loss.item() / len(data))) + # vlog.show_value(torch.mean(kl).item(), name="Kl-loss", tag="Losses") + # vlog.show_value(torch.mean(rec).item(), name="Rec-loss", tag="Losses") + # vlog.show_value(loss.item(), name="Total-loss", tag="Losses") + + print('====> Epoch: {} Average loss: {:.4f}'.format( + epoch, train_loss / len(train_loader.dataset))) + + +def test(model, test_loader, test_loader_abnorm, device, scaling, vlog, elog, image_size, batch_size, log_var_std): + model.eval() + test_loss = [] + kl_loss = [] + rec_loss = [] + with torch.no_grad(): + for i, (data, _) in enumerate(test_loader): + data = data.to(device) + data_flat = data.flatten(start_dim=1).repeat(1, scaling) + recon_batch, mu, logvar = model(data_flat) + loss, kl, rec = loss_function(recon_batch, data_flat, mu, logvar, log_var_std) + test_loss += (kl + rec).tolist() + kl_loss += kl.tolist() + rec_loss += rec.tolist() + if i == 0: + n = min(data.size(0), 8) + comparison = torch.cat([data[:n], + recon_batch[:, :image_size].view(batch_size, 1, 28, 28)[:n]]) + # vlog.show_image_grid(comparison.cpu(), name='reconstruction') + + # vlog.show_value(np.mean(kl_loss), name="Norm-Kl-loss", tag="Anno") + # vlog.show_value(np.mean(rec_loss), name="Norm-Rec-loss", tag="Anno") + # vlog.show_value(np.mean(test_loss), name="Norm-Total-loss", tag="Anno") + # elog.show_value(np.mean(kl_loss), name="Norm-Kl-loss", tag="Anno") + # elog.show_value(np.mean(rec_loss), name="Norm-Rec-loss", tag="Anno") + # elog.show_value(np.mean(test_loss), name="Norm-Total-loss", tag="Anno") + + test_loss_ab = [] + kl_loss_ab = [] + rec_loss_ab = [] + with torch.no_grad(): + for i, (data, _) in enumerate(test_loader_abnorm): + data = data.to(device) + data_flat = data.flatten(start_dim=1).repeat(1, scaling) + recon_batch, mu, logvar = model(data_flat) + loss, kl, rec = loss_function(recon_batch, data_flat, mu, logvar, log_var_std) + test_loss_ab += (kl + rec).tolist() + kl_loss_ab += kl.tolist() + rec_loss_ab += rec.tolist() + if i == 0: + n = min(data.size(0), 8) + comparison = torch.cat([data[:n], + recon_batch[:, :image_size].view(batch_size, 1, 28, 28)[:n]]) + # vlog.show_image_grid(comparison.cpu(), name='reconstruction2') + + print('====> Test set loss: {:.4f}'.format(np.mean(test_loss))) + + # vlog.show_value(np.mean(kl_loss_ab), name="Unorm-Kl-loss", tag="Anno") + # vlog.show_value(np.mean(rec_loss_ab), 
name="Unorm-Rec-loss", tag="Anno") + # vlog.show_value(np.mean(test_loss_ab), name="Unorm-Total-loss", tag="Anno") + # elog.show_value(np.mean(kl_loss_ab), name="Unorm-Kl-loss", tag="Anno") + # elog.show_value(np.mean(rec_loss_ab), name="Unorm-Rec-loss", tag="Anno") + # elog.show_value(np.mean(test_loss_ab), name="Unorm-Total-loss", tag="Anno") + + kl_roc, kl_pr = elog.get_classification_metrics(kl_loss + kl_loss_ab, + [0] * len(kl_loss) + [1] * len(kl_loss_ab), + )[0] + rec_roc, rec_pr = elog.get_classification_metrics(rec_loss + rec_loss_ab, + [0] * len(rec_loss) + [1] * len(rec_loss_ab), + )[0] + loss_roc, loss_pr = elog.get_classification_metrics(test_loss + test_loss_ab, + [0] * len(test_loss) + [1] * len(test_loss_ab), + )[0] + + # vlog.show_value(np.mean(kl_roc), name="KL-loss", tag="ROC") + # vlog.show_value(np.mean(rec_roc), name="Rec-loss", tag="ROC") + # vlog.show_value(np.mean(loss_roc), name="Total-loss", tag="ROC") + # elog.show_value(np.mean(kl_roc), name="KL-loss", tag="ROC") + # elog.show_value(np.mean(rec_roc), name="Rec-loss", tag="ROC") + # elog.show_value(np.mean(loss_roc), name="Total-loss", tag="ROC") + + # vlog.show_value(np.mean(kl_pr), name="KL-loss", tag="PR") + # vlog.show_value(np.mean(rec_pr), name="Rec-loss", tag="PR") + # vlog.show_value(np.mean(loss_pr), name="Total-loss", tag="PR") + + return kl_roc, rec_roc, loss_roc, kl_pr, rec_pr, loss_pr + + +def model_run(scaling, batch_size, odd_class, z, seed=123, log_var_std=0, n_epochs=25): + set_seed(seed) + + config = Config( + scaling=scaling, batch_size=batch_size, odd_class=odd_class, z=z, seed=seed, log_var_std=log_var_std, + n_epochs=n_epochs + ) + + image_size = 784 + input_size = image_size * scaling + device = torch.device("cuda") + + def get_same_index(ds, label, invert=False): + label_indices = [] + for i in range(len(ds)): + if invert: + if ds[i][1] != label: + label_indices.append(i) + if not invert: + if ds[i][1] == label: + label_indices.append(i) + return label_indices + + kwargs = {'num_workers': 1, 'pin_memory': True} + train_set = datasets.FashionMNIST('/home/david/data/datasets/fashion_mnist', train=True, download=True, + transform=transforms.ToTensor()) + test_set = datasets.FashionMNIST('/home/david/data/datasets/fashion_mnist', train=False, + transform=transforms.ToTensor()) + + train_indices_zero = get_same_index(train_set, odd_class, invert=True) + train_zero_set = torch.utils.data.sampler.SubsetRandomSampler(train_indices_zero) + test_indices_zero = get_same_index(test_set, odd_class, invert=True) + test_zero_set = torch.utils.data.sampler.SubsetRandomSampler(test_indices_zero) + test_indices_ones = get_same_index(test_set, odd_class) + test_one_set = torch.utils.data.sampler.SubsetRandomSampler(test_indices_ones) + + train_loader = torch.utils.data.DataLoader(train_set, sampler=train_zero_set, + batch_size=batch_size, shuffle=False, **kwargs) + test_loader = torch.utils.data.DataLoader(test_set, sampler=test_zero_set, + batch_size=batch_size, shuffle=False, **kwargs) + test_loader_abnorm = torch.utils.data.DataLoader(test_set, sampler=test_one_set, + batch_size=batch_size, shuffle=False, **kwargs) + + model = VAE(z=z, input_size=input_size).to(device) + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + vlog = PytorchVisdomLogger(exp_name="vae-fmnist") + elog = PytorchExperimentLogger(base_dir="/home/david/data/logs/mnist_exp_fin", exp_name="fashion-mnist_vae") + + elog.save_config(config, "config") + + for epoch in range(1, n_epochs + 1): + train(epoch, model, optimizer, 
train_loader, device, scaling, vlog, elog, log_var_std) + + kl_roc, rec_roc, loss_roc, kl_pr, rec_pr, loss_pr = test(model, test_loader, test_loader_abnorm, device, + scaling, vlog, elog, + image_size, batch_size, log_var_std) + + with open(os.path.join(elog.result_dir, "results.json"), "w") as file_: + json.dump({ + "kl_roc": kl_roc, "rec_roc": rec_roc, "loss_roc": loss_roc, + "kl_pr": kl_pr, "rec_pr": rec_pr, "loss_pr": loss_pr, + }, file_, indent=4) + + +if __name__ == '__main__': + scaling = 1 + batch_size = 128 + odd_class = 0 + z = 20 + seed = 123 + log_var_std = 0. + + model_run(scaling, batch_size, odd_class, z, seed, log_var_std) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/__pycache__/__init__.cpython-36.pyc b/models/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..db8f25c Binary files /dev/null and b/models/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/__pycache__/enc_dec.cpython-36.pyc b/models/__pycache__/enc_dec.cpython-36.pyc new file mode 100644 index 0000000..89afd12 Binary files /dev/null and b/models/__pycache__/enc_dec.cpython-36.pyc differ diff --git a/models/enc_dec.py b/models/enc_dec.py new file mode 100644 index 0000000..2a3a381 --- /dev/null +++ b/models/enc_dec.py @@ -0,0 +1,247 @@ +import numpy as np +import torch +import torch.nn as nn + + +class NoOp(nn.Module): + + def __init__(self, *args, **kwargs): + super(NoOp, self).__init__() + + def forward(self, x, *args, **kwargs): + return x + + +class ConvModule(nn.Module): + def __init__(self, in_channels, out_channels, conv_op=nn.Conv2d, conv_params=None, + normalization_op=nn.BatchNorm2d, normalization_params=None, + activation_op=nn.LeakyReLU, activation_params=None): + + super(ConvModule, self).__init__() + + self.conv_params = conv_params + if self.conv_params is None: + self.conv_params = {} + self.activation_params = activation_params + if self.activation_params is None: + self.activation_params = {} + self.normalization_params = normalization_params + if self.normalization_params is None: + self.normalization_params = {} + + self.conv = None + if conv_op is not None and not isinstance(conv_op, str): + self.conv = conv_op(in_channels, out_channels, **self.conv_params) + + self.normalization = None + if normalization_op is not None and not isinstance(normalization_op, str): + self.normalization = normalization_op(out_channels, **self.normalization_params) + + self.activation = None + if activation_op is not None and not isinstance(activation_op, str): + self.activation = activation_op(**self.activation_params) + + def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None): + + x = input + + if self.conv is not None: + if conv_add_input is None: + x = self.conv(x) + else: + x = self.conv(x, **conv_add_input) + + if self.normalization is not None: + if normalization_add_input is None: + x = self.normalization(x) + else: + x = self.normalization(x, **normalization_add_input) + + if self.activation is not None: + if activation_add_input is None: + x = self.activation(x) + else: + x = self.activation(x, **activation_add_input) + + return x + + +# Basic Generator +class Generator(nn.Module): + def __init__(self, image_size, z_dim=256, h_size=(256, 128, 64), + upsample_op=nn.ConvTranspose2d, normalization_op=nn.InstanceNorm2d, activation_op=nn.LeakyReLU, + conv_params=None, activation_params=None, block_op=None, block_params=None, to_1x1=True): + + 
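+        # The generator mirrors the encoder: a z_dim x 1 x 1 code is expanded by one
+        # transposed conv to the smallest feature map, then each middle block doubles
+        # the spatial size (kernel 4, stride 2, padding 1) up to the image resolution.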
+        super(Generator, self).__init__()
+
+        if conv_params is None:
+            conv_params = {}
+
+        n_channels = image_size[0]
+        img_size = np.array([image_size[1], image_size[2]])
+
+        if not isinstance(h_size, (list, tuple)):
+            raise AttributeError("h_size has to be either a list or a tuple")
+        elif len(h_size) < 2:
+            raise AttributeError("h_size has to contain at least two elements")
+        else:
+            h_size_bot = h_size[0]
+
+        # We need to know up front how many upsampling stages we will use
+        img_size_new = img_size // (2 ** len(h_size))
+        if np.min(img_size_new) < 2 and z_dim is not None:
+            raise AttributeError("h_size too long: one image dimension has already shrunk below 2")
+
+        ### Start block
+        # project the z_dim x 1 x 1 latent code up to the first feature map
+        if not to_1x1:
+            kernel_size_start = [min(4, i) for i in img_size_new]
+        else:
+            kernel_size_start = img_size_new.tolist()
+
+        if z_dim is not None:
+            self.start = ConvModule(z_dim, h_size_bot,
+                                    conv_op=upsample_op,
+                                    conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False,
+                                                     **conv_params),
+                                    normalization_op=normalization_op,
+                                    normalization_params={},
+                                    activation_op=activation_op,
+                                    activation_params=activation_params
+                                    )
+
+            img_size_new = img_size_new * 2
+        else:
+            self.start = NoOp()
+
+        ### Middle block (upsample until we reach ? x image_size/2 x image_size/2)
+        self.middle_blocks = nn.ModuleList()
+
+        for h_size_top in h_size[1:]:
+
+            if block_op is not None and not isinstance(block_op, str):
+                self.middle_blocks.append(
+                    block_op(h_size_bot, **block_params)
+                )
+
+            self.middle_blocks.append(
+                ConvModule(h_size_bot, h_size_top,
+                           conv_op=upsample_op,
+                           conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params),
+                           normalization_op=normalization_op,
+                           normalization_params={},
+                           activation_op=activation_op,
+                           activation_params=activation_params
+                           )
+            )
+
+            h_size_bot = h_size_top
+            img_size_new = img_size_new * 2
+
+        ### End block
+        self.end = ConvModule(h_size_bot, n_channels,
+                              conv_op=upsample_op,
+                              conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params),
+                              normalization_op=None,
+                              activation_op=None)
+
+    def forward(self, inpt, **kwargs):
+        output = self.start(inpt, **kwargs)
+        for middle in self.middle_blocks:
+            output = middle(output, **kwargs)
+        output = self.end(output, **kwargs)
+        return output
+
+
+# Basic Encoder
+class Encoder(nn.Module):
+    def __init__(self, image_size, z_dim=256, h_size=(64, 128, 256),
+                 conv_op=nn.Conv2d, normalization_op=nn.InstanceNorm2d, activation_op=nn.LeakyReLU,
+                 conv_params=None, activation_params=None,
+                 block_op=None, block_params=None,
+                 to_1x1=True):
+        super(Encoder, self).__init__()
+
+        if conv_params is None:
+            conv_params = {}
+
+        n_channels = image_size[0]
+        img_size_new = np.array([image_size[1], image_size[2]])
+
+        if not isinstance(h_size, (list, tuple)):
+            raise AttributeError("h_size has to be either a list or a tuple")
+        else:
+            h_size_bot = h_size[0]
+
+        ### Start block
+        self.start = ConvModule(n_channels, h_size_bot,
+                                conv_op=conv_op,
+                                conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params),
+                                normalization_op=normalization_op,
+                                normalization_params={},
+                                activation_op=activation_op,
+                                activation_params=activation_params
+                                )
+        img_size_new = img_size_new // 2
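+        # Each stride-2 conv halves the spatial size; img_size_new tracks this so the
+        # final kernel can collapse the remaining map (e.g. a 64x64 input -> 4x4 after
+        # four stages) into a 1x1 latent code when to_1x1 is set.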
+        ### Middle block (downsample until we reach ? x 4 x 4)
+        self.middle_blocks = nn.ModuleList()
+
+        for h_size_top in h_size[1:]:
+
+            if block_op is not None and not isinstance(block_op, str):
+                self.middle_blocks.append(
+                    block_op(h_size_bot, **block_params)
+                )
+
+            self.middle_blocks.append(
+                ConvModule(h_size_bot, h_size_top,
+                           conv_op=conv_op,
+                           conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params),
+                           normalization_op=normalization_op,
+                           normalization_params={},
+                           activation_op=activation_op,
+                           activation_params=activation_params
+                           )
+            )
+
+            h_size_bot = h_size_top
+            img_size_new = img_size_new // 2
+
+        if np.min(img_size_new) < 2 and z_dim is not None:
+            raise AttributeError("h_size too long: one image dimension has already shrunk below 2")
+
+        ### End block
+        if not to_1x1:
+            kernel_size_end = [min(4, i) for i in img_size_new]
+        else:
+            kernel_size_end = img_size_new.tolist()
+
+        if z_dim is not None:
+            self.end = ConvModule(h_size_bot, z_dim,
+                                  conv_op=conv_op,
+                                  conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False,
+                                                   **conv_params),
+                                  normalization_op=None,
+                                  activation_op=None,
+                                  )
+
+            if to_1x1:
+                self.output_size = (z_dim, 1, 1)
+            else:
+                self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(img_size_new, kernel_size_end)])
+        else:
+            self.end = NoOp()
+            self.output_size = img_size_new
+
+    def forward(self, inpt, **kwargs):
+        output = self.start(inpt, **kwargs)
+        for middle in self.middle_blocks:
+            output = middle(output, **kwargs)
+        output = self.end(output, **kwargs)
+        return output
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/__pycache__/__init__.cpython-36.pyc b/utils/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..8899196
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-36.pyc differ
diff --git a/utils/__pycache__/util.cpython-36.pyc b/utils/__pycache__/util.cpython-36.pyc
new file mode 100644
index 0000000..8252e35
Binary files /dev/null and b/utils/__pycache__/util.cpython-36.pyc differ
diff --git a/utils/preprcess_brain.py b/utils/preprcess_brain.py
new file mode 100755
index 0000000..f20254a
--- /dev/null
+++ b/utils/preprcess_brain.py
@@ -0,0 +1,203 @@
+import os
+from collections import defaultdict
+
+import nibabel as nib
+import numpy as np
+import scipy.ndimage as snd
+from medpy.io import load
+
+
+def normalize_img(img, mask):
+    xp_mean = np.mean(img[mask])
+    xp_std = np.std(img[mask]) + 1e-8
+
+    img[mask] = img[mask] - xp_mean
+    img[mask] = img[mask] / xp_std
+
+    return img
+
+
+if __name__ == '__main__':
+
+    ###### BRATS17
+    start_dir = ""
+    target_dir = ""
+
+    subj_dict = defaultdict(dict)
+    mn_list = []
+
+    for root, dirs, files in os.walk(start_dir):
+        for f_name in files:
+            if f_name.endswith("seg.nii.gz") or f_name.endswith("t1.nii.gz") or f_name.endswith("t2.nii.gz"):
+                print(os.path.join(root, f_name))
+                pat_nr = root.split("/")[-1]
+
+                file_name = os.path.join(root, f_name)
+
+                imgs, image_header = load(file_name)
+                imgs = imgs[None]
+
+                if f_name.endswith("seg.nii.gz"):
+                    imgs = imgs.astype(np.int32)
+                else:
+                    imgs = imgs.astype(np.float32)
+
+                x = np.where(imgs > 0.00)
+
+                img_nn = imgs[x]
+                perc_val = np.percentile(img_nn, 0.05)
+
+                xp = np.where(imgs > perc_val)
+
+                xp_mean = np.mean(imgs[xp])
+                xp_std = np.std(imgs[xp]) + 1e-8
+
+                imgs[x] = imgs[x] - xp_mean
+                imgs[x] = imgs[x] / xp_std
+
+                imgs = imgs.transpose((0, 3, 2, 1))
+                imgs = imgs[:, ::-1, :, :]
+
+                imgs = imgs[:, :,
+                       imgs.shape[2] // 2 - 95:imgs.shape[2] // 2 + 95,
+                       imgs.shape[3] // 2 -
77:imgs.shape[3] // 2 + 78] + + imgs = imgs[:, ::-1, :, :] + + imgs_big = np.zeros((1, 155, 190, 165)) + imgs_big[:, :, :, 5:160] = imgs + + print(imgs.shape) + + if f_name.endswith("t1.nii.gz"): + subj_dict[pat_nr]["t1"] = imgs_big[0] + elif f_name.endswith("t2.nii.gz"): + subj_dict[pat_nr]["t2"] = imgs_big[0] + elif f_name.endswith("seg.nii.gz"): + subj_dict[pat_nr]["label"] = imgs_big[0] + + cntr = 1 + for pat_nr, vals in subj_dict.items(): + if "t1" in vals and "t2" in vals and "label" in vals: + pat_nr = "{:05d}".format(cntr) + final_array = np.stack((vals["t1"], vals["t2"], vals["label"])) + target_file = os.path.join(target_dir, pat_nr + ".npy") + np.save(target_file, final_array) + cntr += 1 + print(pat_nr) + + print("Brats17 done.") + exit() + + ###### ISLES 2015 Real + + start_dir = "" + target_dir = "" + + subj_dict = defaultdict(dict) + mn_list = [] + + for root, dirs, files in os.walk(start_dir): + for f_name in files: + if ".MR_T1." in f_name or ".MR_T2." in f_name or "XX.O.OT" in f_name: + print(os.path.join(root, f_name)) + pat_nr = root[len(start_dir):len(start_dir) + 2].replace("/", "_").lower() + + file_name = os.path.join(root, f_name) + + imgs, image_header = load(file_name) + imgs = imgs[None] + + if "XX.O.OT" in f_name: + imgs = imgs.astype(np.int32) + else: + imgs = imgs.astype(np.float32) + + x = np.where(imgs > 0.00) + + img_nn = imgs[x] + perc_val = np.percentile(img_nn, 0.05) + + xp = np.where(imgs > 0.) + + xp_mean = np.mean(imgs[xp]) + xp_std = np.std(imgs[xp]) + 1e-8 + + imgs[x] = imgs[x] - xp_mean + imgs[x] = imgs[x] / xp_std + + imgs = imgs.transpose((0, 3, 2, 1)) + + imgs = imgs[:, :, imgs.shape[2] // 2 - 95:imgs.shape[2] // 2 + 95, + imgs.shape[3] // 2 - 77:imgs.shape[3] // 2 + 78] + + imgs = imgs[:, :, ::-1] + + if ".MR_T1." in f_name: + subj_dict[pat_nr]["t1"] = imgs[0] + elif ".MR_T2." 
in f_name: + subj_dict[pat_nr]["t2"] = imgs[0] + elif "XX.O.OT" in f_name: + subj_dict[pat_nr]["label"] = imgs[0] + + for pat_nr, vals in subj_dict.items(): + if "t1" in vals and "t2" in vals and "label" in vals: + final_array = np.stack((vals["t1"], vals["t2"], vals["label"])) + target_file = os.path.join(target_dir, pat_nr + ".npy") + np.save(target_file, final_array) + print(pat_nr) + + print("Isles15_siss done.") + exit() + + #### HCP + + start_dir = "" + target_dir = "" + + t1_templ = "T1w_acpc_dc_restore_brain.nii.gz" + t2_templ = "T2w_acpc_dc_restore_brain.nii.gz" + label_templ = "wmparc.nii.gz" + + i = 0 + + for subj in os.listdir(start_dir): + sub_dir = os.path.join(start_dir, subj, "T1w/") + if os.path.isdir(sub_dir): + t1_file = os.path.join(sub_dir, t1_templ) + t2_file = os.path.join(sub_dir, t2_templ) + label_file = os.path.join(sub_dir, label_templ) + + t1_array = load(t1_file)[0] + t2_array = load(t2_file)[0] + label_array = load(label_file)[0] + + t1_array = snd.zoom(t1_array, (0.75, 0.75, 0.75), order=1) + t1_array = t1_array.transpose((2, 1, 0)) + t1_array = t1_array[:, ::-1, :] + t1_array = t1_array[0:165, 15:225, 15:180] + t1_mask = np.where(t1_array != 0) + normalize_img(t1_array, t1_mask) + + t2_array = snd.zoom(t2_array, (0.75, 0.75, 0.75), order=1) + t2_array = t2_array.transpose((2, 1, 0)) + t2_array = t2_array[:, ::-1, :] + t2_array = t2_array[0:165, 15:225, 15:180] + + t2_mask = np.where(t2_array != 0) + normalize_img(t2_array, t2_mask) + + label_array = snd.zoom(label_array, (0.75, 0.75, 0.75), order=0) + label_array = label_array.transpose((2, 1, 0)) + label_array = label_array[:, ::-1, :] + label_array = label_array[0:165, 15:225, 15:180] + + final_array = np.stack((t1_array, t2_array, label_array)) + target_file = os.path.join(target_dir, subj + ".npy") + np.save(target_file, final_array) + + print(i) + i += 1 + + print("HCP Done.") + exit() diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..0214f01 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,125 @@ +import numpy as np +import torch + + +def f1_score(y_pred, y_label, dims=0, eps=1e-6): + """Calculates the f1 score of a sample (4d, one-hot encoded) """ + + true_positives = torch.sum(y_pred * y_label, dim=dims) + pos_data = torch.sum(y_label, dim=dims) + pos_pred = torch.sum(y_pred, dim=dims) + # false_negatives = torch.sum(torch.sum(y_label, dim=2), dim=2) - true_positives + # false_positives = torch.sum(torch.sum(y_pred, dim=2), dim=2) - true_positives + + precision = true_positives / (pos_pred + eps) + recall = true_positives / (pos_data + eps) + + f1_s = 2 * (precision * recall) / (precision + recall + eps) + + # f1_s = (2*true_positives) / (2* true_positves + false_positives + fase_negatives) + + # dices_scores has shape (batch_size, num_classes) + return f1_s + + +def calc_hard_dice(x, y, thresh): + if torch.is_tensor(x): + x = x.detach().cpu().numpy() + elif isinstance(x, (list, tuple)): + x = np.asarray(x) + + if isinstance(y, (list, tuple)): + y = torch.from_numpy(np.asarray(y)).float() + + x_binary = x > thresh + x_binary = torch.from_numpy(x_binary.astype(int)) + + dice = f1_score(x_binary.float(), y.float()) + + del x, y, x_binary + + dice = dice.item() if torch.is_tensor(dice) else dice + + return dice + + +def find_best_val(x, y, val_fn, val_range=(0, 1), max_steps=4, step=0, max_val=0, max_point=0): + if step == max_steps: + return max_val, max_point + + if val_range[0] == val_range[1]: + val_range = (val_range[0], 1) + + bottom = val_range[0] + top = val_range[1] + 
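+    # Evaluate the metric at the two quartiles of the current range, keep the better
+    # half and recurse: a coarse bisection that needs only 2 * max_steps metric
+    # evaluations instead of a dense threshold sweep.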
+    center = bottom + (top - bottom) * 0.5
+
+    q_bottom = bottom + (top - bottom) * 0.25
+    q_top = bottom + (top - bottom) * 0.75
+
+    val_bottom = val_fn(x, y, q_bottom)
+    val_top = val_fn(x, y, q_top)
+
+    if val_bottom > val_top:
+        if val_bottom > max_val:
+            max_val = val_bottom
+            max_point = q_bottom
+        return find_best_val(x, y, val_fn, val_range=(bottom, center), step=step + 1, max_steps=max_steps,
+                             max_val=max_val, max_point=max_point)
+    else:
+        if val_top > max_val:
+            max_val = val_top
+            max_point = q_top
+        return find_best_val(x, y, val_fn, val_range=(center, top), step=step + 1, max_steps=max_steps,
+                             max_val=max_val, max_point=max_point)
+
+
+def smooth_tensor(tensor, kernel_size=8, sigma=3, channels=1):
+    # Build a fixed (non-learnable) gaussian filter; an even kernel_size is
+    # reduced by one so the kernel has a center pixel.
+    if kernel_size % 2 == 0:
+        kernel_size -= 1
+
+    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
+    x_cord = torch.arange(kernel_size)
+    x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
+    y_grid = x_grid.t()
+    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
+
+    mean = (kernel_size - 1) / 2.
+    variance = sigma ** 2.
+
+    # Calculate the 2-dimensional gaussian kernel which is
+    # the product of two gaussian distributions for two different
+    # variables (in this case called x and y)
+    import math
+    gaussian_kernel = (1. / (2. * math.pi * variance)) * \
+                      torch.exp(
+                          -torch.sum((xy_grid - mean) ** 2., dim=-1) / \
+                          (2. * variance)
+                      )
+    # Make sure sum of values in gaussian kernel equals 1.
+    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+    # Reshape to 2d depthwise convolutional weight
+    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
+
+    gaussian_filter = torch.nn.Conv2d(in_channels=channels, out_channels=channels,
+                                      kernel_size=kernel_size, groups=channels, bias=False,
+                                      padding=kernel_size // 2)
+
+    gaussian_filter.weight.data = gaussian_kernel
+    gaussian_filter.weight.requires_grad = False
+
+    gaussian_filter.to(tensor.device)
+
+    return gaussian_filter(tensor)
+
+
+def normalize(tensor):
+    # min-max scale to [0, 1] on a detached cpu copy
+    tens_data = tensor.detach().cpu()
+    tens_data -= float(np.min(tens_data.numpy()))
+    tens_data /= float(np.max(tens_data.numpy()))
+
+    return tens_data
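Note: a minimal usage sketch of how the threshold search and hard-Dice scoring in
utils/util.py compose (the random scores and labels here are placeholder data, not
taken from the repository):

    import numpy as np
    from utils.util import calc_hard_dice, find_best_val

    scores = np.random.rand(10000).tolist()                        # per-pixel anomaly scores
    labels = (np.random.rand(10000) > 0.95).astype(int).tolist()   # flattened ground-truth mask
    best_dice, best_thresh = find_best_val(scores, labels, calc_hard_dice,
                                           max_steps=8, val_range=(0, max(scores)))
    print(best_dice, best_thresh)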