diff --git a/example_algos/__init__.py b/example_algos/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/example_algos/__init__.py @@ -0,0 +1 @@ + diff --git a/example_algos/algorithms/__init__.py b/example_algos/algorithms/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/example_algos/algorithms/__init__.py @@ -0,0 +1 @@ + diff --git a/example_algos/algorithms/abstract_algorithm.py b/example_algos/algorithms/abstract_algorithm.py new file mode 100644 index 0000000..1bb0456 --- /dev/null +++ b/example_algos/algorithms/abstract_algorithm.py @@ -0,0 +1,19 @@ +from abc import abstractmethod + + +class AbstractMOODAlgorithm: + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def train(self): + pass + + @abstractmethod + def score_sample(self): + pass + + @abstractmethod + def score_pixels(self): + pass diff --git a/example_algos/algorithms/ae_2d.py b/example_algos/algorithms/ae_2d.py new file mode 100644 index 0000000..510ab90 --- /dev/null +++ b/example_algos/algorithms/ae_2d.py @@ -0,0 +1,266 @@ +import os +import sys +import time + +import click +import numpy as np +import torch +import torch.distributions as dist +from torch import optim +from trixi.logger import PytorchExperimentLogger +from trixi.util.config import monkey_patch_fn_args_as_config +from trixi.util.pytorchexperimentstub import PytorchExperimentStub +from trixi.util.pytorchutils import get_smooth_image_gradient +from tqdm import tqdm +from monai.transforms.transforms import Resize +from math import ceil + +from example_algos.data.numpy_dataset import get_numpy2d_dataset, get_numpy3d_dataset +from example_algos.models.aes import AE +from example_algos.util.nifti_io import ni_load, ni_save + + +class AE2D: + @monkey_patch_fn_args_as_config + def __init__( + self, + input_shape, + lr=1e-4, + n_epochs=20, + z_dim=512, + model_feature_map_sizes=(16, 64, 256, 1024), + load_path=None, + log_dir=None, + logger="visdom", + print_every_iter=100, + ): + + self.print_every_iter = print_every_iter + self.n_epochs = n_epochs + self.batch_size = input_shape[0] + self.z_dim = z_dim + self.input_shape = input_shape + self.logger = logger + + log_dict = {} + if logger is not None: + log_dict = { + 0: (logger), + } + self.tx = PytorchExperimentStub(name="delme", base_dir=log_dir, config=fn_args_as_config, loggers=log_dict,) + + cuda_available = torch.cuda.is_available() + self.device = torch.device("cuda" if cuda_available else "cpu") + + self.model = AE(input_size=input_shape[1:], z_dim=z_dim, fmap_sizes=model_feature_map_sizes).to(self.device) + self.optimizer = optim.Adam(self.model.parameters(), lr=lr) + + self.vae_loss_ema = 1 + self.theta = 1 + + if load_path is not None: + PytorchExperimentLogger.load_model_static(self.model, os.path.join(load_path, "ae_final.pth")) + time.sleep(5) + + def train(self): + + train_loader = get_numpy2d_dataset( + base_dir="/fast/moody/brain/train", + num_processes=16, + pin_memory=True, + batch_size=self.batch_size, + mode="train", + ) + val_loader = get_numpy2d_dataset( + base_dir="/fast/moody/brain/train", num_processes=8, pin_memory=True, batch_size=self.batch_size, mode="val" + ) + + for epoch in range(self.n_epochs): + + ### Train + self.model.train() + + train_loss = 0 + print("\nStart epoch ", epoch) + data_loader_ = tqdm(enumerate(train_loader)) + for batch_idx, data in data_loader_: + inpt = data.to(self.device) + + self.optimizer.zero_grad() + inpt_rec = self.model(inpt) + + loss = torch.mean(torch.pow(inpt - inpt_rec, 2)) + 
loss.backward() + self.optimizer.step() + + train_loss += loss.item() + if batch_idx % self.print_every_iter == 0: + status_str = ( + f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} " + f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: " + f"{loss.item() / len(inpt):.6f}" + ) + data_loader_.set_description_str(status_str) + + cnt = epoch * len(train_loader) + batch_idx + self.tx.add_result(loss.item(), name="Train-Loss", tag="Losses", counter=cnt) + + if self.logger is not None: + self.tx.l[0].show_image_grid(inpt, name="Input", image_args={"normalize": True}) + self.tx.l[0].show_image_grid(inpt_rec, name="Reconstruction", image_args={"normalize": True}) + + print(f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}") + + ### Validate + self.model.eval() + + val_loss = 0 + with torch.no_grad(): + data_loader_ = tqdm(enumerate(val_loader)) + data_loader_.set_description_str("Validating") + for i, data in data_loader_: + inpt = data.to(self.device) + inpt_rec = self.model(inpt) + + loss = torch.mean(torch.pow(inpt - inpt_rec, 2)) + val_loss += loss.item() + + self.tx.add_result( + val_loss / len(val_loader), name="Val-Loss", tag="Losses", counter=(epoch + 1) * len(train_loader) + ) + + print(f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}") + + self.tx.save_model(self.model, "ae_final") + + time.sleep(10) + + def score_sample(self, np_array): + + orig_shape = np_array.shape + to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3]), mode="bilinear") + from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]), mode="bilinear") + + data_tensor = torch.from_numpy(np_array).float() + data_tensor = to_transforms(data_tensor[None])[0] + slice_scores = [] + + for i in range(ceil(orig_shape[0] / self.batch_size)): + batch = data_tensor[i * self.batch_size : (i + 1) * self.batch_size].unsqueeze(1) + batch = batch.to(self.device) + + with torch.no_grad(): + batch_rec = self.model(batch) + loss = torch.mean(torch.pow(batch - batch_rec, 2), dim=(1, 2, 3)) + + slice_scores += loss.cpu().tolist() + + return np.max(slice_scores) + + def score_pixels(self, np_array): + + orig_shape = np_array.shape + to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3]), mode="bilinear") + from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]), mode="bilinear") + + data_tensor = torch.from_numpy(np_array).float() + data_tensor = to_transforms(data_tensor[None])[0] + target_tensor = torch.zeros_like(data_tensor) + + for i in range(ceil(orig_shape[0] / self.batch_size)): + batch = data_tensor[i * self.batch_size : (i + 1) * self.batch_size].unsqueeze(1) + batch = batch.to(self.device) + + batch_rec = self.model(batch) + + loss = torch.pow(batch - batch_rec, 2).squeeze() + target_tensor[i * self.batch_size : (i + 1) * self.batch_size] = loss.cpu() + + target_tensor = from_transforms(target_tensor[None])[0] + + return target_tensor.detach().numpy() + + def print(self, *args): + print(*args) + self.tx.print(*args) + + +@click.option("-m", "--mode", default="pixel", type=click.Choice(["pixel", "sample"], case_sensitive=False)) +@click.option( + "-r", "--run", default="train", type=click.Choice(["train", "predict", "test", "all"], case_sensitive=False) +) +@click.option("--target-size", type=click.IntRange(1, 512, clamp=True), default=128) +@click.option("--batch-size", type=click.IntRange(1, 512, clamp=True), default=16) +@click.option("--n-epochs", type=int, default=20) +@click.option("--lr", type=float, 
default=1e-4) +@click.option("--z-dim", type=int, default=128) +@click.option("-fm", "--fmap-sizes", type=int, multiple=True, default=[16, 64, 256, 1024]) +@click.option("--print-every-iter", type=int, default=100) +@click.option("-l", "--load-path", type=click.Path(exists=True), required=False, default=None) +@click.option("-o", "--log-dir", type=click.Path(exists=True, writable=True), required=False, default=None) +@click.option( + "--logger", type=click.Choice(["visdom", "tensorboard"], case_sensitive=False), required=False, default="visdom" +) +@click.option("-t", "--test-dir", type=click.Path(exists=True), required=False, default=None) +@click.option("-p", "--pred-dir", type=click.Path(exists=True, writable=True), required=False, default=None) +@click.command() +def main( + mode="pixel", + run="train", + target_size=128, + batch_size=16, + n_epochs=20, + lr=1e-4, + z_dim=128, + fmap_sizes=(16, 64, 256, 1024), + print_every_iter=100, + load_path=None, + log_dir=None, + logger="visdom", + test_dir=None, + pred_dir=None, +): + + from scripts.evalresults import eval_dir + + input_shape = (batch_size, 1, target_size, target_size) + + ae_algo = AE2D( + input_shape, + log_dir=log_dir, + n_epochs=n_epochs, + lr=lr, + z_dim=z_dim, + model_feature_map_sizes=fmap_sizes, + print_every_iter=print_every_iter, + load_path=load_path, + logger=logger, + ) + + if run == "train" or run == "all": + ae_algo.train() + + if run == "predict" or run == "all": + + for f_name in os.listdir(test_dir): + ni_file = os.path.join(test_dir, f_name) + ni_data, ni_aff = ni_load(ni_file) + if mode == "pixel": + pixel_scores = ae_algo.score_pixels(ni_data) + ni_save(os.path.join(pred_dir, f_name), pixel_scores, ni_aff) + if mode == "sample": + sample_score = ae_algo.score_sample(ni_data) + with open(os.path.join(pred_dir, f_name + ".txt"), "w") as target_file: + target_file.write(str(sample_score)) + + if run == "test" or run == "all": + + test_dir = test_dir[:-1] if test_dir.endswith("/") else test_dir + score = eval_dir(pred_dir=pred_dir, label_dir=test_dir + f"_label/{mode}", mode=mode) + + print(score) + + +if __name__ == "__main__": + + main() diff --git a/example_algos/algorithms/ae_3d.py b/example_algos/algorithms/ae_3d.py new file mode 100644 index 0000000..e4114c3 --- /dev/null +++ b/example_algos/algorithms/ae_3d.py @@ -0,0 +1,280 @@ +import os +import sys +import time + +import click +import numpy as np +import torch +import torch.distributions as dist +from torch import optim +from trixi.logger import PytorchExperimentLogger +from trixi.util.config import monkey_patch_fn_args_as_config +from trixi.util.pytorchexperimentstub import PytorchExperimentStub +from trixi.util.pytorchutils import get_smooth_image_gradient +from tqdm import tqdm +from monai.transforms.transforms import Resize +from math import ceil + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) +sys.path.append(os.path.join(os.path.dirname(__file__), "../../..")) + +from example_algos.data.numpy_dataset import get_numpy2d_dataset, get_numpy3d_dataset +from example_algos.models.aes import AE +from example_algos.util.nifti_io import ni_load, ni_save + + +class AE3D: + @monkey_patch_fn_args_as_config + def __init__( + self, + input_shape, + lr=1e-4, + n_epochs=20, + z_dim=512, + model_feature_map_sizes=(16, 64, 256, 1024), + load_path=None, + log_dir=None, + logger="visdom", + print_every_iter=100, + ): + + self.print_every_iter = print_every_iter + self.n_epochs = 
n_epochs + self.batch_size = input_shape[0] + self.z_dim = z_dim + self.input_shape = input_shape + self.logger = logger + + log_dict = {} + if logger is not None: + log_dict = { + 0: (logger), + } + self.tx = PytorchExperimentStub(name="delme", base_dir=log_dir, config=fn_args_as_config, loggers=log_dict,) + + cuda_available = torch.cuda.is_available() + self.device = torch.device("cuda" if cuda_available else "cpu") + + self.model = AE( + input_size=input_shape[1:], + z_dim=z_dim, + fmap_sizes=model_feature_map_sizes, + conv_op=torch.nn.Conv3d, + tconv_op=torch.nn.ConvTranspose3d, + ).to(self.device) + self.optimizer = optim.Adam(self.model.parameters(), lr=lr) + + self.vae_loss_ema = 1 + self.theta = 1 + + if load_path is not None: + PytorchExperimentLogger.load_model_static(self.model, os.path.join(load_path, "ae_final.pth")) + time.sleep(5) + + def train(self): + + train_loader = get_numpy3d_dataset( + base_dir="/fast/moody/brain/train", + num_processes=16, + pin_memory=False, + batch_size=self.batch_size, + mode="train", + ) + val_loader = get_numpy3d_dataset( + base_dir="/fast/moody/brain/train", + num_processes=8, + pin_memory=False, + batch_size=self.batch_size, + mode="val", + ) + + for epoch in range(self.n_epochs): + + ### Train + self.model.train() + + train_loss = 0 + print("\nStart epoch ", epoch) + data_loader_ = tqdm(enumerate(train_loader)) + for batch_idx, data in data_loader_: + inpt = data.to(self.device) + + self.optimizer.zero_grad() + inpt_rec = self.model(inpt) + + loss = torch.mean(torch.pow(inpt - inpt_rec, 2)) + loss.backward() + self.optimizer.step() + + train_loss += loss.item() + if batch_idx % self.print_every_iter == 0: + status_str = ( + f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} " + f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: " + f"{loss.item() / len(inpt):.6f}" + ) + data_loader_.set_description_str(status_str) + + cnt = epoch * len(train_loader) + batch_idx + self.tx.add_result(loss.item(), name="Train-Loss", tag="Losses", counter=cnt) + + if self.logger is not None: + self.tx.l[0].show_image_grid( + inpt[:, :, self.input_shape[2] // 2], name="Input", image_args={"normalize": True} + ) + self.tx.l[0].show_image_grid( + inpt_rec[:, :, self.input_shape[2] // 2], + name="Reconstruction", + image_args={"normalize": True}, + ) + + print(f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}") + + ### Validate + self.model.eval() + + val_loss = 0 + with torch.no_grad(): + data_loader_ = tqdm(enumerate(val_loader)) + data_loader_.set_description_str("Validating") + for i, data in data_loader_: + inpt = data.to(self.device) + inpt_rec = self.model(inpt) + + loss = torch.mean(torch.pow(inpt - inpt_rec, 2)) + val_loss += loss.item() + + self.tx.add_result( + val_loss / len(val_loader), name="Val-Loss", tag="Losses", counter=(epoch + 1) * len(train_loader) + ) + + print(f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}") + + self.tx.save_model(self.model, "ae_final") + + time.sleep(10) + + def score_sample(self, np_array): + + orig_shape = np_array.shape + to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3], self.input_shape[4])) + from_transforms = torch.nn.Upsample((orig_shape[0], orig_shape[1], orig_shape[2])) + + data_tensor = torch.from_numpy(np_array).float() + data_tensor = to_transforms(data_tensor[None][None]) + + with torch.no_grad(): + inpt = data_tensor.to(self.device) + inpt_rec = self.model(inpt) + + loss = torch.mean(torch.pow(inpt - inpt_rec, 2)) + + score = 
loss.cpu().item() + + return score + + def score_pixels(self, np_array): + + orig_shape = np_array.shape + to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3], self.input_shape[4])) + from_transforms = torch.nn.Upsample((orig_shape[0], orig_shape[1], orig_shape[2])) + + data_tensor = torch.from_numpy(np_array).float() + data_tensor = to_transforms(data_tensor[None][None]) + + with torch.no_grad(): + inpt = data_tensor.to(self.device) + inpt_rec = self.model(inpt) + + loss = torch.pow(inpt - inpt_rec, 2) + + target_tensor = loss.cpu().detach() + target_tensor = from_transforms(target_tensor)[0][0] + + return target_tensor.detach().numpy() + + def print(self, *args): + print(*args) + self.tx.print(*args) + + +@click.option("-m", "--mode", default="pixel", type=click.Choice(["pixel", "sample"], case_sensitive=False)) +@click.option( + "-r", "--run", default="train", type=click.Choice(["train", "predict", "test", "all"], case_sensitive=False) +) +@click.option("--target-size", type=click.IntRange(1, 512, clamp=True), default=128) +@click.option("--batch-size", type=click.IntRange(1, 512, clamp=True), default=4) +@click.option("--n-epochs", type=int, default=20) +@click.option("--lr", type=float, default=1e-4) +@click.option("--z-dim", type=int, default=128) +@click.option("-fm", "--fmap-sizes", type=int, multiple=True, default=[16, 64, 256, 1024]) +@click.option("--print-every-iter", type=int, default=10) +@click.option("-l", "--load-path", type=click.Path(exists=True), required=False, default=None) +@click.option("-o", "--log-dir", type=click.Path(exists=True, writable=True), required=False, default=None) +@click.option( + "--logger", type=click.Choice(["visdom", "tensorboard"], case_sensitive=False), required=False, default=None +) +@click.option("-t", "--test-dir", type=click.Path(exists=True), required=False, default=None) +@click.option("-p", "--pred-dir", type=click.Path(exists=True, writable=True), required=False, default=None) +@click.command() +def main( + mode="pixel", + run="train", + target_size=128, + batch_size=16, + n_epochs=20, + lr=1e-4, + z_dim=128, + fmap_sizes=(16, 64, 256, 1024), + print_every_iter=100, + load_path=None, + log_dir=None, + logger="visdom", + test_dir=None, + pred_dir=None, +): + + from scripts.evalresults import eval_dir + + input_shape = (batch_size, 1, target_size, target_size, target_size) + + ae_algo = AE3D( + input_shape, + log_dir=log_dir, + n_epochs=n_epochs, + lr=lr, + z_dim=z_dim, + model_feature_map_sizes=fmap_sizes, + print_every_iter=print_every_iter, + load_path=load_path, + logger=logger, + ) + + if run == "train" or run == "all": + ae_algo.train() + + if run == "predict" or run == "all": + + for f_name in os.listdir(test_dir): + ni_file = os.path.join(test_dir, f_name) + ni_data, ni_aff = ni_load(ni_file) + if mode == "pixel": + pixel_scores = ae_algo.score_pixels(ni_data) + ni_save(os.path.join(pred_dir, f_name), pixel_scores, ni_aff) + if mode == "sample": + sample_score = ae_algo.score_sample(ni_data) + with open(os.path.join(pred_dir, f_name + ".txt"), "w") as target_file: + target_file.write(str(sample_score)) + + if run == "test" or run == "all": + + test_dir = test_dir[:-1] if test_dir.endswith("/") else test_dir + score = eval_dir(pred_dir=pred_dir, label_dir=test_dir + f"_label/{mode}", mode=mode) + + print(score) + + +if __name__ == "__main__": + + main() diff --git a/example_algos/data/__init__.py b/example_algos/data/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ 
b/example_algos/data/__init__.py @@ -0,0 +1 @@ + diff --git a/example_algos/data/numpy_dataset.py b/example_algos/data/numpy_dataset.py new file mode 100644 index 0000000..3640a80 --- /dev/null +++ b/example_algos/data/numpy_dataset.py @@ -0,0 +1,367 @@ +import fnmatch +import os +import random +import shutil +import string +import time +from abc import abstractmethod +from collections import defaultdict +from time import sleep + +import numpy as np +from monai.transforms import AddChannel, Compose, Resize, ScaleIntensity, ToTensor +from torch.utils.data import DataLoader, Dataset + + +def get_transforms_2d(target_size=128): + """Returns a Transform which resizes 2D samples (1xHxW) to a target_size (1 x target_size x target_size) + and then converts them to a pytorch tensor. + + Args: + target_size (int, optional): [New spatial dimension of the input data]. Defaults to 128. + + Returns: + [Transform] + """ + + transforms = Compose([Resize((target_size, target_size)), ToTensor()]) + return transforms + + +def get_transforms_3d(target_size=128): + """Returns a Transform which resizes 3D samples (1xZxYxX) to a target_size (1 x target_size x target_size x target_size) + and then converts them to a pytorch tensor. + + Args: + target_size (int, optional): [New spatial dimension of the input data]. Defaults to 128. + + Returns: + [Transform] + """ + transforms = Compose([Resize((target_size, target_size, target_size)), ToTensor()]) + return transforms + + 
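+# Usage sketch (illustrative only, hypothetical values): per the docstring above, the 2D pipeline +# takes a (1 x H x W) sample and should yield a (1 x target_size x target_size) pytorch tensor: +# transforms_2d = get_transforms_2d(target_size=128) +# slice_tensor = transforms_2d(np.random.rand(1, 256, 256).astype(np.float32))  # expected: (1, 128, 128) + + 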
+def get_numpy2d_dataset( + base_dir, + mode="train", + batch_size=16, + n_items=None, + pin_memory=False, + num_processes=1, + drop_last=False, + target_size=128, + file_pattern="*data.npy", + do_reshuffle=True, + slice_offset=0, + caching=True, +): + """Returns a Pytorch data loader which loads a Numpy2dDataSet, i.e. 2D slices from a dir of 3D Numpy arrays. + + Args: + base_dir ([str]): [Directory in which the npy files are.] + mode (str, optional): [train or val, loads the first 90% for train and 10% for val]. Defaults to "train". + batch_size (int, optional): [See pytorch DataLoader]. Defaults to 16. + n_items ([int], optional): [Number of items in one iteration, by default number of files in the loaded set + but can be smaller (uses subset) or larger (uses file multiple times)]. Defaults to None. + pin_memory (bool, optional): [See pytorch DataLoader]. Defaults to False. + num_processes (int, optional): [See pytorch DataLoader]. Defaults to 1. + drop_last (bool, optional): [See pytorch DataLoader]. Defaults to False. + target_size (int, optional): [New spatial dimension to which the input data will be transformed]. Defaults to 128. + file_pattern (str, optional): [File pattern of files to load from the base_dir]. Defaults to "*data.npy". + do_reshuffle (bool, optional): [See pytorch DataLoader]. Defaults to True. + slice_offset (int, optional): [Offset for the first dimension to skip the first/last n slices]. Defaults to 0. + caching (bool, optional): [If True saves the data set list to a file in the base_dir + so files don't have to be indexed again and can be more quickly loaded using the cache file]. Defaults to True. + + Returns: + [DataLoader]: Pytorch data loader which loads a Numpy2dDataSet + """ + + transforms = get_transforms_2d(target_size=target_size) + + data_set = Numpy2dDataSet( + base_dir=base_dir, + mode=mode, + n_items=n_items, + file_pattern=file_pattern, + slice_offset=slice_offset, + caching=caching, + transforms=transforms, + ) + + data_loader = DataLoader( + data_set, + batch_size=batch_size, + shuffle=do_reshuffle, + num_workers=num_processes, + pin_memory=pin_memory, + drop_last=drop_last, + ) + return data_loader + + +def get_numpy3d_dataset( + base_dir, + mode="train", + batch_size=16, + n_items=None, + pin_memory=False, + num_processes=1, + drop_last=False, + target_size=128, + file_pattern="*data.npy", + do_reshuffle=True, +): + """Returns a Pytorch data loader which loads a Numpy3dDataSet, i.e. 3D Numpy arrays from a directory. + + Args: + base_dir ([str]): [Directory in which the npy files are.] + mode (str, optional): [train or val, loads the first 90% for train and 10% for val]. Defaults to "train". + batch_size (int, optional): [See pytorch DataLoader]. Defaults to 16. + n_items ([int], optional): [Number of items in one iteration, by default number of files in the loaded set + but can be smaller (uses subset) or larger (uses file multiple times)]. Defaults to None. + pin_memory (bool, optional): [See pytorch DataLoader]. Defaults to False. + num_processes (int, optional): [See pytorch DataLoader]. Defaults to 1. + drop_last (bool, optional): [See pytorch DataLoader]. Defaults to False. + target_size (int, optional): [New spatial dimension to which the input data will be transformed]. Defaults to 128. + file_pattern (str, optional): [File pattern of files to load from the base_dir]. Defaults to "*data.npy". + do_reshuffle (bool, optional): [See pytorch DataLoader]. Defaults to True. + + Returns: + [DataLoader]: Pytorch data loader which loads a Numpy3dDataSet + """ + + transforms = get_transforms_3d(target_size=target_size) + + data_set = Numpy3dDataSet( + base_dir=base_dir, mode=mode, n_items=n_items, file_pattern=file_pattern, transforms=transforms, + ) + + data_loader = DataLoader( + data_set, + batch_size=batch_size, + shuffle=do_reshuffle, + num_workers=num_processes, + pin_memory=pin_memory, + drop_last=drop_last, + ) + return data_loader + + 
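+# Usage sketch (illustrative only; the path is hypothetical): build the 2D training loader and draw one +# batch, the same way AE2D.train does with its hard-coded base_dir: +# loader = get_numpy2d_dataset(base_dir="/data/brain/train", mode="train", batch_size=16, target_size=128) +# batch = next(iter(loader))  # expected: torch tensor of shape (16, 1, 128, 128) + + 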
+class Numpy2dDataSet(Dataset): + def __init__( + self, + base_dir, + mode="train", + n_items=None, + file_pattern="*data.npy", + slice_offset=0, + caching=True, + transforms=None, + ): + """Dataset which loads 2D slices from a dir of 3D Numpy arrays. + + Args: + base_dir ([str]): [Directory in which the npy files are.] + mode (str, optional): [train or val, loads the first 90% for train and 10% for val]. Defaults to "train". + n_items ([type], optional): [Number of items in one iteration, by default number of files in the loaded set + but can be smaller (uses subset) or larger (uses file multiple times)]. Defaults to None. + file_pattern (str, optional): [File pattern of files to load from the base_dir]. Defaults to "*data.npy". + slice_offset (int, optional): [Offset for the first dimension to skip the first/last n slices]. Defaults to 0. + caching (bool, optional): [If True saves the data set list to a file in the base_dir + so files don't have to be indexed again and can be more quickly loaded using the cache file]. Defaults to True. + transforms ([type], optional): [Transformation to do after loading the dataset -> pytorch data transforms]. Defaults to None. + """ + + self.base_dir = base_dir + self.items = self.load_dataset( + base_dir, mode=mode, pattern=file_pattern, slice_offset=slice_offset, caching=caching + ) + self.transforms = transforms + + self.data_len = len(self.items) + if n_items is None: + self.n_items = self.data_len + else: + self.n_items = int(n_items) + + self.reshuffle() + + def reshuffle(self): + print("Reshuffle...") + random.shuffle(self.items) + # the stable sort restores the grouping by file index, so slices are only shuffled within each file + self.items.sort(key=lambda x: x[0]) + + def __len__(self): + return self.n_items + + def __getitem__(self, item): + + if item >= self.n_items: + raise IndexError()  # map-style datasets signal exhaustion with IndexError + + idx = item % self.data_len + data_smpl = self.get_data_by_idx(idx) + + if self.transforms is not None: + data_smpl = self.transforms(data_smpl) + + return data_smpl + + def get_data_by_idx(self, idx): + """Returns a data sample for a given index, i.e. a Np-Array slice + + Args: + idx ([int]): [Index of the data sample] + + Returns: + [np.ndarray]: [3-D Numpy array 1xHxW] + """ + + slice_info = self.items[idx] + fn_name = slice_info[1] + slice_idx = slice_info[2] + + numpy_array = np.load(fn_name, mmap_mode="r") + numpy_slice = numpy_array[slice_idx : slice_idx + 1] + + del numpy_array + + return numpy_slice.astype(np.float32) + + def load_dataset(self, base_dir, mode="train", pattern="*data.npy", slice_offset=0, caching=True): + """Indexes all files in the given directory and returns a list of 2-D slices (file_index, npy_file, slice_index_for_np_file) + (so they can be loaded with get_data_by_idx) + + Args: + base_dir ([str]): [Directory in which the npy files are.] + mode (str, optional): [train or val, loads the first 90% for train and 10% for val]. Defaults to "train". + file_pattern (str, optional): [File pattern of files to load from the base_dir]. Defaults to "*data.npy". + slice_offset (int, optional): [Offset for the first dimension to skip the first/last n slices]. Defaults to 0. + caching (bool, optional): [If True saves the data set list to a file in the base_dir + so files don't have to be indexed again and can be more quickly loaded using the cache file]. Defaults to True. + + Returns: + [list]: [List of (Numpy_file X slices in the file) which should be used in the dataset] + """ + slices = [] + + if caching: + # np.save appends ".npy" to file names, so use that extension to keep save and load symmetric + cache_file = os.path.join(base_dir, f"cache_file_{mode}_{slice_offset}.npy") + if os.path.exists(cache_file): + slices = np.load(cache_file, allow_pickle=True).tolist() + return slices + + all_files = os.listdir(base_dir) + npy_files = fnmatch.filter(all_files, pattern) + n_files = len(npy_files) + + if mode == "train": + load_files = npy_files[: int(0.9 * n_files)] + elif mode == "val": + load_files = npy_files[int(0.9 * n_files) :] + else: + load_files = [] + + for i, filename in enumerate(sorted(load_files)): + npy_file = os.path.join(base_dir, filename) + numpy_array = np.load(npy_file, mmap_mode="r") + + file_len = numpy_array.shape[0]  # slices are taken along the first axis in get_data_by_idx + + slices.extend([(i, npy_file, j) for j in range(slice_offset, file_len - slice_offset)]) + + if caching: + np.save(cache_file, np.array(slices, dtype=object)) + + return slices + + 
+class Numpy3dDataSet(Dataset): + def __init__( + self, base_dir, mode="train", n_items=None, file_pattern="*data.npy", transforms=None, + ): + """A Dataset that loads 3D Numpy arrays from a (flat) directory + + Args: + base_dir ([str]): [Directory in which the npy files are.] + mode (str, optional): [train or val, loads the first 90% for train and 10% for val]. Defaults to "train". + n_items ([type], optional): [Number of items in one iteration, by default number of files in the loaded set + but can be smaller (uses subset) or larger (uses file multiple times)]. Defaults to None. + file_pattern (str, optional): [File pattern of files to load from the base_dir]. Defaults to "*data.npy". + transforms ([type], optional): [Transformation to do after loading the dataset -> pytorch data transforms]. Defaults to None. + """ + + self.base_dir = base_dir + self.items = self.load_dataset(base_dir, mode=mode, pattern=file_pattern) + self.transforms = transforms + + self.data_len = len(self.items) + if n_items is None: + self.n_items = self.data_len + else: + self.n_items = int(n_items) + + def reshuffle(self): + print("Reshuffle...") + random.shuffle(self.items) + + def __len__(self): + return self.n_items + + def __getitem__(self, item): + + if item >= self.n_items: + raise IndexError()  # map-style datasets signal exhaustion with IndexError + + idx = item % self.data_len + + data_smpl = self.get_data_by_idx(idx) + + if self.transforms is not None: + data_smpl = self.transforms(data_smpl) + + return data_smpl + + def get_data_by_idx(self, idx): + """Returns a data sample for a given index + + Args: + idx ([int]): [Index of the data sample] + + Returns: + [np.ndarray]: [4-D Numpy array 1xZxYxX] + """ + + fn_name = self.items[idx] + numpy_array = np.load(fn_name, mmap_mode="r") + return numpy_array.astype(np.float32)[None] + + @staticmethod + def load_dataset(base_dir, mode="train", pattern="*data.npy"): + """Indexes all files in the given directory (so they can be loaded with get_data_by_idx) + + Args: + base_dir ([str]): [Directory in which the npy files are.] + mode (str, optional): [train or val, loads the first 90% for train and 10% for val]. Defaults to "train". + file_pattern (str, optional): [File pattern of files to load from the base_dir]. Defaults to "*data.npy". 
+ + Returns: + [list]: [List of files which should be used in the dataset] + """ + + all_files = os.listdir(base_dir) + npy_files = fnmatch.filter(all_files, pattern) + n_files = len(npy_files) + + if mode == "train": + load_files = npy_files[: int(0.9 * n_files)] + elif mode == "val": + load_files = npy_files[int(0.9 * n_files) :] + else: + load_files = [] + + load_files = [os.path.join(base_dir, fn) for fn in load_files] + + return load_files diff --git a/example_algos/data/preprocess.py b/example_algos/data/preprocess.py new file mode 100644 index 0000000..998365b --- /dev/null +++ b/example_algos/data/preprocess.py @@ -0,0 +1,45 @@ +import argparse +import os + +import nibabel as nib +import numpy as np +from tqdm import tqdm + + +def nifti_to_numpy(input_folder: str, output_folder: str): + """Converts all nifti files in an input folder to numpy and saves the data and affine matrix into the output folder + + Args: + input_folder (str): Folder to read the nifti files from + output_folder (str): Folder to write the numpy arrays to + """ + + for fname in tqdm(sorted(os.listdir(input_folder))): + + if not fname.endswith("nii.gz"): + continue + + n_file = os.path.join(input_folder, fname) + nifti = nib.load(n_file) + + np_data = nifti.get_fdata() + np_affine = nifti.affine + + f_basename = fname.split(".")[0] + + np.save(os.path.join(output_folder, f_basename + "_data.npy"), np_data.astype(np.float16)) + np.save(os.path.join(output_folder, f_basename + "_aff.npy"), np_affine) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input_dir", required=True, type=str) + parser.add_argument("-o", "--output_dir", required=False, type=str) + + args = parser.parse_args() + + input_dir = args.input_dir + output_dir = args.output_dir + + nifti_to_numpy(input_dir, output_dir) diff --git a/example_algos/models/__init__.py b/example_algos/models/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/example_algos/models/__init__.py @@ -0,0 +1 @@ + diff --git a/example_algos/models/aes.py b/example_algos/models/aes.py new file mode 100644 index 0000000..ddff124 --- /dev/null +++ b/example_algos/models/aes.py @@ -0,0 +1,239 @@ +import numpy as np +import torch +import torch.distributions as dist + +from example_algos.models.nets import BasicEncoder, BasicGenerator + + +class VAE(torch.nn.Module): + def __init__( + self, + input_size, + z_dim=256, + fmap_sizes=(16, 64, 256, 1024), + to_1x1=True, + conv_op=torch.nn.Conv2d, + conv_params=None, + tconv_op=torch.nn.ConvTranspose2d, + tconv_params=None, + normalization_op=None, + normalization_params=None, + activation_op=torch.nn.LeakyReLU, + activation_params=None, + block_op=None, + block_params=None, + *args, + **kwargs + ): + """Basic VAE built from a symmetric BasicEncoder (Encoder) and BasicGenerator (Decoder) + + Args: + input_size ((int, int, int)): [Size of the input in format CxHxW] + z_dim (int, optional): [Dimension of the latent / input dimension (C channel-dim)]. Defaults to 256. + fmap_sizes (tuple, optional): [Defines the upsampling levels of the generator, list/ tuple of ints, where each + int defines the number of feature maps in the layer]. Defaults to (16, 64, 256, 1024). + to_1x1 (bool, optional): [If True, the last conv layer maps to a z_dim x 1 x 1 latent vector (similar to fully connected); + if False, the latent keeps a spatial resolution (z_dim x H x W, using the conv kernel size given in conv_params)]. + Defaults to True. 
+ conv_op ([torch.nn.Module], optional): [Convolution operation used in the encoder to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d. + conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). + tconv_op ([torch.nn.Module], optional): [Upsampling/ Transposed Conv operation used in the decoder to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d. + tconv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=4, stride=2, padding=1, bias=False). + normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to None. + normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. + activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. + activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. + block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. + block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. + """ + + super(VAE, self).__init__() + + input_size_enc = list(input_size) + input_size_dec = list(input_size) + + self.enc = BasicEncoder( + input_size=input_size_enc, + fmap_sizes=fmap_sizes, + z_dim=z_dim * 2, + conv_op=conv_op, + conv_params=conv_params, + normalization_op=normalization_op, + normalization_params=normalization_params, + activation_op=activation_op, + activation_params=activation_params, + block_op=block_op, + block_params=block_params, + to_1x1=to_1x1, + ) + self.dec = BasicGenerator( + input_size=input_size_dec, + fmap_sizes=fmap_sizes[::-1], + z_dim=z_dim, + upsample_op=tconv_op, + conv_params=tconv_params, + normalization_op=normalization_op, + normalization_params=normalization_params, + activation_op=activation_op, + activation_params=activation_params, + block_op=block_op, + block_params=block_params, + to_1x1=to_1x1, + ) + + self.hidden_size = self.enc.output_size + + def forward(self, inpt, sample=True, no_dist=False, **kwargs): + y1 = self.enc(inpt, **kwargs) + + mu, log_std = torch.chunk(y1, 2, dim=1) + std = torch.exp(log_std) + z_dist = dist.Normal(mu, std) + if sample: + z_sample = z_dist.rsample() + else: + z_sample = mu + + x_rec = self.dec(z_sample) + + if no_dist: + return x_rec + else: + return x_rec, z_dist + + def encode(self, inpt, **kwargs): + """Encodes a sample and returns the parameters for the approximate inference distribution (Normal) + + Args: + inpt ([tensor]): The input to encode + + Returns: + mu : The mean used to parameterize a Normal distribution + std: The standard deviation used to parameterize a Normal distribution + """ + enc = self.enc(inpt, **kwargs) + mu, log_std = torch.chunk(enc, 2, dim=1) + std = torch.exp(log_std) + return mu, std + + def decode(self, inpt, **kwargs): + """Decodes a latent space sample, using the generative model (decode = mu_{gen}(z) as used in p(x|z) = N(x | mu_{gen}(z), 1)). 
+ + Args: + inpt ([tensor]): A sample from the latent space to decode + + Returns: + [tensor]: [The decoded sample back in the input space] + """ + x_rec = self.dec(inpt, **kwargs) + return x_rec + + 
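+# Sketch of a VAE objective built on VAE.forward (illustrative only: this commit trains just the AE, so +# the batch x, model vae and weight beta below are assumptions, with a standard-normal prior): +# x_rec, z_dist = vae(x) +# rec_loss = torch.mean(torch.pow(x - x_rec, 2)) +# kl_loss = dist.kl_divergence(z_dist, dist.Normal(0.0, 1.0)).mean() +# loss = rec_loss + beta * kl_loss + + 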
+ """ + super(AE, self).__init__() + + input_size_enc = list(input_size) + input_size_dec = list(input_size) + + self.enc = BasicEncoder( + input_size=input_size_enc, + fmap_sizes=fmap_sizes, + z_dim=z_dim, + conv_op=conv_op, + conv_params=conv_params, + normalization_op=normalization_op, + normalization_params=normalization_params, + activation_op=activation_op, + activation_params=activation_params, + block_op=block_op, + block_params=block_params, + to_1x1=to_1x1, + ) + self.dec = BasicGenerator( + input_size=input_size_dec, + fmap_sizes=fmap_sizes[::-1], + z_dim=z_dim, + upsample_op=tconv_op, + conv_params=tconv_params, + normalization_op=normalization_op, + normalization_params=normalization_params, + activation_op=activation_op, + activation_params=activation_params, + block_op=block_op, + block_params=block_params, + to_1x1=to_1x1, + ) + + self.hidden_size = self.enc.output_size + + def forward(self, inpt, **kwargs): + + y1 = self.enc(inpt, **kwargs) + + x_rec = self.dec(y1) + + return x_rec + + def encode(self, inpt, **kwargs): + """Encodes a input sample to a latent space sample + + Args: + inpt ([tensor]): Input sample + + Returns: + enc: Encoded input sample in the latent space + """ + enc = self.enc(inpt, **kwargs) + return enc + + def decode(self, inpt, **kwargs): + """Decodes a latent space sample back to the input space + + Args: + inpt ([tensor]): [Latent space sample] + + Returns: + [rec]: [Encoded latent sample back in the input space] + """ + rec = self.dec(inpt, **kwargs) + return rec diff --git a/example_algos/models/nets.py b/example_algos/models/nets.py new file mode 100644 index 0000000..9c9faf7 --- /dev/null +++ b/example_algos/models/nets.py @@ -0,0 +1,456 @@ +import warnings + +import numpy as np +import torch +import torch.nn as nn + + +class NoOp(nn.Module): + def __init__(self, *args, **kwargs): + """NoOp Pytorch Module. + Forwards the given input as is. + """ + super(NoOp, self).__init__() + + def forward(self, x, *args, **kwargs): + return x + + +class ConvModule(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + conv_op=nn.Conv2d, + conv_params=None, + normalization_op=None, + normalization_params=None, + activation_op=nn.LeakyReLU, + activation_params=None, + ): + """Basic Conv Pytorch Conv Module + Has can have a Conv Op, a Normlization Op and a Non Linearity: + x = conv(x) + x = some_norm(x) + x = nonlin(x) + + Args: + in_channels ([int]): [Number on input channels/ feature maps] + out_channels ([int]): [Number of ouput channels/ feature maps] + conv_op ([torch.nn.Module], optional): [Conv operation]. Defaults to nn.Conv2d. + conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. + normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...)]. Defaults to None. + normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. + activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...)]. Defaults to nn.LeakyReLU. + activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 
+ """ + + super(ConvModule, self).__init__() + + self.conv_params = conv_params + if self.conv_params is None: + self.conv_params = {} + self.activation_params = activation_params + if self.activation_params is None: + self.activation_params = {} + self.normalization_params = normalization_params + if self.normalization_params is None: + self.normalization_params = {} + + self.conv = None + if conv_op is not None and not isinstance(conv_op, str): + self.conv = conv_op(in_channels, out_channels, **self.conv_params) + + self.normalization = None + if normalization_op is not None and not isinstance(normalization_op, str): + self.normalization = normalization_op(out_channels, **self.normalization_params) + + self.activation = None + if activation_op is not None and not isinstance(activation_op, str): + self.activation = activation_op(**self.activation_params) + + def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None): + + x = input + + if self.conv is not None: + if conv_add_input is None: + x = self.conv(x) + else: + x = self.conv(x, **conv_add_input) + + if self.normalization is not None: + if normalization_add_input is None: + x = self.normalization(x) + else: + x = self.normalization(x, **normalization_add_input) + + if self.activation is not None: + if activation_add_input is None: + x = self.activation(x) + else: + x = self.activation(x, **activation_add_input) + + # nn.functional.dropout(x, p=0.95, training=True) + + return x + + +class ConvBlock(nn.Module): + def __init__( + self, + n_convs: int, + n_featmaps: int, + conv_op=nn.Conv2d, + conv_params=None, + normalization_op=nn.BatchNorm2d, + normalization_params=None, + activation_op=nn.LeakyReLU, + activation_params=None, + ): + """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) + + Args: + n_convs ([type]): [Number of convolutions] + n_featmaps ([type]): [Feature map size of the conv] + conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. + conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. + normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. + normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. + activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. + activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 
+ """ + + super(ConvBlock, self).__init__() + + self.n_featmaps = n_featmaps + self.n_convs = n_convs + self.conv_params = conv_params + if self.conv_params is None: + self.conv_params = {} + + self.conv_list = nn.ModuleList() + + for i in range(self.n_convs): + conv_layer = ConvModule( + n_featmaps, + n_featmaps, + conv_op=conv_op, + conv_params=conv_params, + normalization_op=normalization_op, + normalization_params=normalization_params, + activation_op=activation_op, + activation_params=activation_params, + ) + self.conv_list.append(conv_layer) + + def forward(self, input, **frwd_params): + x = input + for conv_layer in self.conv_list: + x = conv_layer(x) + + return x + + +class ResBlock(nn.Module): + def __init__( + self, + n_convs, + n_featmaps, + conv_op=nn.Conv2d, + conv_params=None, + normalization_op=nn.BatchNorm2d, + normalization_params=None, + activation_op=nn.LeakyReLU, + activation_params=None, + ): + """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) and a skip/ residual connection: + x = input + x = conv_block(x) + out = x + input + + Args: + n_convs ([type]): [Number of convolutions in the conv block] + n_featmaps ([type]): [Feature map size of the conv block] + conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. + conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. + normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. + normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. + activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. + activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. + """ + super(ResBlock, self).__init__() + + self.n_featmaps = n_featmaps + self.n_convs = n_convs + self.conv_params = conv_params + if self.conv_params is None: + self.conv_params = {} + + self.conv_block = ConvBlock( + n_featmaps, + n_convs, + conv_op=conv_op, + conv_params=conv_params, + normalization_op=normalization_op, + normalization_params=normalization_params, + activation_op=activation_op, + activation_params=activation_params, + ) + + def forward(self, input, **frwd_params): + x = input + x = self.conv_block(x) + + out = x + input + + return out + + +# Basic Generator +class BasicGenerator(nn.Module): + def __init__( + self, + input_size, + z_dim=256, + fmap_sizes=(256, 128, 64), + upsample_op=nn.ConvTranspose2d, + conv_params=None, + normalization_op=NoOp, + normalization_params=None, + activation_op=nn.LeakyReLU, + activation_params=None, + block_op=NoOp, + block_params=None, + to_1x1=True, + ): + """Basic configureable Generator/ Decoder. + Allows for mutilple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. + + Args: + input_size ((int, int, int): Size of the input in format CxHxW): + z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). + fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each + int defines the number of feature maps in the layer]. Defaults to (256, 128, 64). 
+# Basic Generator +class BasicGenerator(nn.Module): + def __init__( + self, + input_size, + z_dim=256, + fmap_sizes=(256, 128, 64), + upsample_op=nn.ConvTranspose2d, + conv_params=None, + normalization_op=NoOp, + normalization_params=None, + activation_op=nn.LeakyReLU, + activation_params=None, + block_op=NoOp, + block_params=None, + to_1x1=True, + ): + """Basic configurable Generator/ Decoder. + Allows for multiple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. + + Args: + input_size ((int, int, int)): [Size of the input in format CxHxW] + z_dim (int, optional): [Dimension of the latent / input dimension (C channel-dim)]. Defaults to 256. + fmap_sizes (tuple, optional): [Defines the upsampling levels of the generator, list/ tuple of ints, where each + int defines the number of feature maps in the layer]. Defaults to (256, 128, 64). + upsample_op ([torch.nn.Module], optional): [Upsampling operation used to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d. + conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=4, stride=2, padding=1, bias=False). + normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to NoOp. + normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. + activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. + activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. + block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. + block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. + to_1x1 (bool, optional): [If True the latent dimension is a z_dim x 1 x 1 vector, if False a spatial resolution other than 1x1 is allowed (z_dim x H x W)]. Defaults to True. + """ + + super(BasicGenerator, self).__init__() + + if conv_params is None: + conv_params = dict(kernel_size=4, stride=2, padding=1, bias=False) + if block_op is None: + block_op = NoOp + if block_params is None: + block_params = {} + + n_channels = input_size[0] + input_size_ = np.array(input_size[1:]) + + if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): + raise AttributeError("fmap_sizes has to be either a list or tuple or an int") + elif len(fmap_sizes) < 2: + raise AttributeError("fmap_sizes has to contain at least two elements") + else: + h_size_bot = fmap_sizes[0] + + # We need to know how many layers we will use at the beginning + input_size_new = input_size_ // (2 ** len(fmap_sizes)) + if np.min(input_size_new) < 2 and z_dim is not None: + raise AttributeError("fmap_sizes too long, one image dimension has already perished") + + ### Start block + start_block = [] + + if not to_1x1: + kernel_size_start = [min(conv_params["kernel_size"], i) for i in input_size_new] + else: + kernel_size_start = input_size_new.tolist() + + if z_dim is not None: + self.start = ConvModule( + z_dim, + h_size_bot, + conv_op=upsample_op, + conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False), + normalization_op=normalization_op, + normalization_params=normalization_params, + activation_op=activation_op, + activation_params=activation_params, + ) + + input_size_new = input_size_new * 2 + else: + self.start = NoOp() + + ### Middle block (Done until we reach ? 
x input_size/2 x input_size/2) + self.middle_blocks = nn.ModuleList() + + for h_size_top in fmap_sizes[1:]: + + self.middle_blocks.append(block_op(h_size_bot, **block_params)) + + self.middle_blocks.append( + ConvModule( + h_size_bot, + h_size_top, + conv_op=upsample_op, + conv_params=conv_params, + normalization_op=normalization_op, + normalization_params={}, + activation_op=activation_op, + activation_params=activation_params, + ) + ) + + h_size_bot = h_size_top + input_size_new = input_size_new * 2 + + ### End block + self.end = ConvModule( + h_size_bot, + n_channels, + conv_op=upsample_op, + conv_params=conv_params, + normalization_op=None, + activation_op=None, + ) + + def forward(self, inpt, **kwargs): + output = self.start(inpt, **kwargs) + for middle in self.middle_blocks: + output = middle(output, **kwargs) + output = self.end(output, **kwargs) + return output + + +# Basic Encoder +class BasicEncoder(nn.Module): + def __init__( + self, + input_size, + z_dim=256, + fmap_sizes=(64, 128, 256), + conv_op=nn.Conv2d, + conv_params=None, + normalization_op=NoOp, + normalization_params=None, + activation_op=nn.LeakyReLU, + activation_params=None, + block_op=NoOp, + block_params=None, + to_1x1=True, + ): + """Basic configurable Encoder. + Allows for multiple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. + + Args: + z_dim (int, optional): [Dimension of the latent / output dimension (C channel-dim)]. Defaults to 256. + fmap_sizes (tuple, optional): [Defines the downsampling levels of the encoder, list/ tuple of ints, where each + int defines the number of feature maps in the layer]. Defaults to (64, 128, 256). + conv_op ([torch.nn.Module], optional): [Convolution operation used to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d. + conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). + normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to NoOp. + normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. + activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. + activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. + block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each downsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. + block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. + to_1x1 (bool, optional): [If True, the last conv layer maps to a z_dim x 1 x 1 latent vector (similar to fully connected); if False, the latent keeps a spatial resolution (z_dim x H x W, using the conv kernel size given in conv_params)]. Defaults to True. 
+ """ + super(BasicEncoder, self).__init__() + + if conv_params is None: + conv_params = dict(kernel_size=3, stride=2, padding=1, bias=False) + if block_op is None: + block_op = NoOp + if block_params is None: + block_params = {} + + n_channels = input_size[0] + input_size_new = np.array(input_size[1:]) + + if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): + raise AttributeError("fmap_sizes has to be either a list or tuple or an int") + # elif len(fmap_sizes) < 2: + # raise AttributeError("fmap_sizes has to contain at least three elements") + else: + h_size_bot = fmap_sizes[0] + + ### Start block + self.start = ConvModule( + n_channels, + h_size_bot, + conv_op=conv_op, + conv_params=conv_params, + normalization_op=normalization_op, + normalization_params={}, + activation_op=activation_op, + activation_params=activation_params, + ) + input_size_new = input_size_new // 2 + + ### Middle block (Done until we reach ? x 4 x 4) + self.middle_blocks = nn.ModuleList() + + for h_size_top in fmap_sizes[1:]: + + self.middle_blocks.append(block_op(h_size_bot, **block_params)) + + self.middle_blocks.append( + ConvModule( + h_size_bot, + h_size_top, + conv_op=conv_op, + conv_params=conv_params, + normalization_op=normalization_op, + normalization_params={}, + activation_op=activation_op, + activation_params=activation_params, + ) + ) + + h_size_bot = h_size_top + input_size_new = input_size_new // 2 + + if np.min(input_size_new) < 2 and z_dim is not None: + raise ("fmap_sizes to long, one image dimension has already perished") + + ### End block + if not to_1x1: + kernel_size_end = [min(conv_params["kernel_size"], i) for i in input_size_new] + else: + kernel_size_end = input_size_new.tolist() + + if z_dim is not None: + self.end = ConvModule( + h_size_bot, + z_dim, + conv_op=conv_op, + conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False), + normalization_op=None, + activation_op=None, + ) + + if to_1x1: + self.output_size = (z_dim, 1, 1) + else: + self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(input_size_new, kernel_size_end)]) + else: + self.end = NoOp() + self.output_size = input_size_new + + def forward(self, inpt, **kwargs): + output = self.start(inpt, **kwargs) + for middle in self.middle_blocks: + output = middle(output, **kwargs) + output = self.end(output, **kwargs) + return output diff --git a/example_algos/requirements_algos.txt b/example_algos/requirements_algos.txt new file mode 100644 index 0000000..8b34e0f --- /dev/null +++ b/example_algos/requirements_algos.txt @@ -0,0 +1,16 @@ +click==7.1.2 +joblib==0.14.1 +monai==0.1.0 +nibabel==3.1.0 +numpy==1.18.3 +packaging==20.3 +pkg-resources==0.0.0 +pyparsing==2.4.7 +scikit-learn==0.22.2.post1 +scipy==1.4.1 +six==1.14.0 +tensorboard==2.2.1 +torch==1.5.0 +torchvision==0.6.0 +tqdm==4.46.0 +trixi @ git+https://github.com/MIC-DKFZ/trixi \ No newline at end of file diff --git a/example_algos/util/__init__.py b/example_algos/util/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/example_algos/util/__init__.py @@ -0,0 +1 @@ + diff --git a/example_algos/util/nifti_io.py b/example_algos/util/nifti_io.py new file mode 100644 index 0000000..a932d75 --- /dev/null +++ b/example_algos/util/nifti_io.py @@ -0,0 +1,30 @@ + + +def ni_load(f_path): + """Loads a nifti file from a given path and returns the data and affine matrix + + Args: + f_path ([str]): [Path to the nifti file] + + Returns: + data [np.ndarray]: [Nifti file image data] + affine [np.ndarray]: [Nifti file affine 
 diff --git a/example_algos/requirements_algos.txt b/example_algos/requirements_algos.txt new file mode 100644 index 0000000..8b34e0f --- /dev/null +++ b/example_algos/requirements_algos.txt @@ -0,0 +1,16 @@ +click==7.1.2 +joblib==0.14.1 +monai==0.1.0 +nibabel==3.1.0 +numpy==1.18.3 +packaging==20.3 +pkg-resources==0.0.0 +pyparsing==2.4.7 +scikit-learn==0.22.2.post1 +scipy==1.4.1 +six==1.14.0 +tensorboard==2.2.1 +torch==1.5.0 +torchvision==0.6.0 +tqdm==4.46.0 +trixi @ git+https://github.com/MIC-DKFZ/trixi \ No newline at end of file diff --git a/example_algos/util/__init__.py b/example_algos/util/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/example_algos/util/__init__.py @@ -0,0 +1 @@ + diff --git a/example_algos/util/nifti_io.py b/example_algos/util/nifti_io.py new file mode 100644 index 0000000..a932d75 --- /dev/null +++ b/example_algos/util/nifti_io.py @@ -0,0 +1,30 @@ + + +def ni_load(f_path): + """Loads a nifti file from a given path and returns the data and affine matrix + + Args: + f_path ([str]): [Path to the nifti file] + + Returns: + data [np.ndarray]: [Nifti file image data] + affine [np.ndarray]: [Nifti file affine matrix] + """ + import nibabel + + nimg = nibabel.load(f_path) + return nimg.get_fdata(), nimg.affine + + +def ni_save(f_path, ni_data, ni_affine): + """Saves image data and an affine matrix as a new nifti file + + Args: + f_path ([str]): [Path to the nifti file] + ni_data ([np.ndarray]): [Image data] + ni_affine ([np.ndarray]): [Affine matrix] + """ + import nibabel + + nimg = nibabel.Nifti1Image(ni_data, ni_affine) + nibabel.save(nimg, f_path)
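+ + +# Round-trip sketch (illustrative only, hypothetical paths): +# data, affine = ni_load("/data/brain/train/00000.nii.gz") +# ni_save("/tmp/00000_copy.nii.gz", data, affine)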