diff --git a/bin/hyppopy_exe.py b/bin/hyppopy_exe.py index ef0610e..c0f4383 100644 --- a/bin/hyppopy_exe.py +++ b/bin/hyppopy_exe.py @@ -1,69 +1,72 @@ #!/usr/bin/env python # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.unet_usecase.unet_usecase import unet_usecase from hyppopy.workflows.svc_usecase.svc_usecase import svc_usecase from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase +from hyppopy.workflows.imageregistration_usecase.imageregistration_usecase import imageregistration_usecase import os import sys import argparse def print_warning(msg): print("\n!!!!! WARNING !!!!!") print(msg) sys.exit() def args_check(args): if not args.workflow: print_warning("No workflow specified, check --help") if not args.config: print_warning("Missing config parameter, check --help") if not os.path.isfile(args.config): print_warning(f"Couldn't find configfile ({args.config}), please check your input --config") if __name__ == "__main__": parser = argparse.ArgumentParser(description='UNet Hyppopy UseCase Example Optimization.') parser.add_argument('-w', '--workflow', type=str, help='workflow to be executed') parser.add_argument('-c', '--config', type=str, help='config filename, .xml or .json formats are supported.' 'pass a full path filename or the filename only if the' 'configfile is in the data folder') args = parser.parse_args() args_check(args) ProjectManager.read_config(args.config) if args.workflow == "svc_usecase": uc = svc_usecase() elif args.workflow == "randomforest_usecase": uc = randomforest_usecase() elif args.workflow == "unet_usecase": uc = unet_usecase() + elif args.workflow == "imageregistration_usecase": + uc = imageregistration_usecase() else: print("No workflow called {} found!".format(args.workflow)) sys.exit() uc.run() print(uc.get_results()) diff --git a/hyppopy/workflows/imageregistration_usecase/__init__.py b/hyppopy/workflows/imageregistration_usecase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hyppopy/workflows/imageregistration_usecase/imageregistration_usecase.py b/hyppopy/workflows/imageregistration_usecase/imageregistration_usecase.py new file mode 100644 index 0000000..993c72a --- /dev/null +++ b/hyppopy/workflows/imageregistration_usecase/imageregistration_usecase.py @@ -0,0 +1,52 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: + +#------------------------------------------------------ +# this needs to be imported, dont remove these +from hyppopy.projectmanager import ProjectManager +from hyppopy.workflows.workflowbase import WorkflowBase +#------------------------------------------------------ + +# import your external packages +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import cross_val_score + +# import your custom DataLoader +from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase # This is a dataloader class create your own + + +class imageregistration_usecase(WorkflowBase): + + def setup(self): + # here you create your own DataLoader instance + dl = SimpleDataLoaderBase() + # call the start function of your DataLoader + dl.start(path=ProjectManager.data_path, + data_name=ProjectManager.data_name, + labels_name=ProjectManager.labels_name) + # pass the data to the solver + self.solver.set_data(dl.data) + + def blackbox_function(self, data, params): + # converting number back to integers is an ugly hack that will be removed in the future + if "n_estimators" in params.keys(): + params["n_estimators"] = int(round(params["n_estimators"])) + + # Do your training + clf = RandomForestClassifier(**params) + # compute your loss + loss = -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() + # return loss + return loss