diff --git a/__main__.py b/__main__.py new file mode 100644 index 0000000..f9af847 --- /dev/null +++ b/__main__.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import sys +ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(ROOT) + +from hyppopy.projectmanager import ProjectManager +from hyppopy.workflows.unet_usecase.unet_usecase import unet_usecase +from hyppopy.workflows.svc_usecase.svc_usecase import svc_usecase +from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase +from hyppopy.workflows.imageregistration_usecase.imageregistration_usecase import imageregistration_usecase + + +import os +import sys +import time +import argparse + + +def print_warning(msg): + print("\n!!!!! WARNING !!!!!") + print(msg) + sys.exit() + + +def args_check(args): + if not args.workflow: + print_warning("No workflow specified, check --help") + if not args.config: + print_warning("Missing config parameter, check --help") + if not os.path.isfile(args.config): + print_warning(f"Couldn't find configfile ({args.config}), please check your input --config") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='UNet Hyppopy UseCase Example Optimization.') + parser.add_argument('-w', '--workflow', type=str, help='workflow to be executed') + parser.add_argument('-o', '--output', type=str, default=None, help='output path to store result') + parser.add_argument('-c', '--config', type=str, help='config filename, .xml or .json formats are supported.' + 'pass a full path filename or the filename only if the' + 'configfile is in the data folder') + + args = parser.parse_args() + args_check(args) + + ProjectManager.read_config(args.config) + + if args.output is not None: + ProjectManager.output_dir = args.output + + if args.workflow == "svc_usecase": + uc = svc_usecase() + elif args.workflow == "randomforest_usecase": + uc = randomforest_usecase() + elif args.workflow == "unet_usecase": + uc = unet_usecase() + elif args.workflow == "imageregistration_usecase": + uc = imageregistration_usecase() + else: + print("No workflow called {} found!".format(args.workflow)) + sys.exit() + + print("\nStart optimization...") + start = time.process_time() + uc.run() + end = time.process_time() + + print("Finished optimization!\n") + print("Total Time: {}s\n".format(end-start)) + res, best = uc.get_results() + print("---- Optimal Parameter -----\n") + for p in best.items(): + print(" - {}\t:\t{}".format(p[0], p[1])) diff --git a/examples/use_hyppopy_solver.py b/examples/use_hyppopy_solver.py index 8843d4a..93965da 100644 --- a/examples/use_hyppopy_solver.py +++ b/examples/use_hyppopy_solver.py @@ -1,47 +1,47 @@ import os import sys from sklearn.svm import SVC from sklearn.model_selection import cross_val_score # the ProjectManager is loading your config file and giving you access # to everything specified in the settings/custom section of the config from hyppopy.projectmanager import ProjectManager # the SolverFactory builds the Solver class for you from hyppopy.solverfactory import SolverFactory # we use in this example the SimpleDataLoader from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoader # until Hyppopy is not fully installable we need # to set the Hyppopy package folder by hand HYPPOPY_DIR = "D:/MyPythonModules/hyppopy" sys.path.append(HYPPOPY_DIR) # let the ProjectManager read your config file DATA = os.path.join(HYPPOPY_DIR, *("hyppopy", "tests", "data", "Titanic")) ProjectManager.read_config(os.path.join(DATA, 'rf_config.json')) -# ----- reading data ------ +# ----- reading data somehow ------ dl = SimpleDataLoader() dl.start(path=ProjectManager.data_path, data_name=ProjectManager.data_name, labels_name=ProjectManager.labels_name) -# ------------------------- +# --------------------------------- # ----- defining loss function ------ def blackbox_function(params): clf = SVC(**params) return -cross_val_score(estimator=clf, X=dl.data[0], y=dl.data[1], cv=3).mean() # ----------------------------------- # ----- create and run the solver ------ # get a solver instance from the SolverFactory solver = SolverFactory.get_solver() # set your loss function solver.set_loss_function(blackbox_function) # run the solver solver.run() # store your results solver.save_results(savedir="C:\\Users\\Me\\Desktop\\myTestProject") # -------------------------------------- \ No newline at end of file diff --git a/hyppopy/globals.py b/hyppopy/globals.py index 922823c..3b2a6cd 100644 --- a/hyppopy/globals.py +++ b/hyppopy/globals.py @@ -1,30 +1,31 @@ # DKFZ # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. import os import sys import logging ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT) -PLUGIN_DEFAULT_DIR = os.path.join(ROOT, *("hyppopy", "plugins")) -TESTDATA_DIR = os.path.join(ROOT, *("hyppopy", "tests", "data")) +LIBNAME = "hyppopy" +PLUGIN_DEFAULT_DIR = os.path.join(ROOT, *(LIBNAME, "plugins")) +TESTDATA_DIR = os.path.join(ROOT, *(LIBNAME, "tests", "data")) SETTINGSSOLVERPATH = "settings/solver" SETTINGSCUSTOMPATH = "settings/custom" -DEEPDICT_XML_ROOT = "hyppopy" +DEEPDICT_XML_ROOT = LIBNAME -LOGFILENAME = os.path.join(ROOT, 'logfile.log') +LOGFILENAME = os.path.join(ROOT, 'logfile2.log') DEBUGLEVEL = logging.DEBUG logging.basicConfig(filename=LOGFILENAME, filemode='w', format='%(levelname)s: %(name)s - %(message)s') diff --git a/hyppopy/plugins/hyperopt_settings_plugin.py b/hyppopy/plugins/hyperopt_settings_plugin.py index 8ce1f40..a188ef6 100644 --- a/hyppopy/plugins/hyperopt_settings_plugin.py +++ b/hyppopy/plugins/hyperopt_settings_plugin.py @@ -1,105 +1,103 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging import numpy as np from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat try: from hyperopt import hp from yapsy.IPlugin import IPlugin except: LOG.warning("hyperopt package not installed, will ignore this plugin!") print("hyperopt package not installed, will ignore this plugin!") from hyppopy.settingspluginbase import SettingsPluginBase from hyppopy.settingsparticle import SettingsParticle class hyperopt_Settings(SettingsPluginBase, IPlugin): def __init__(self): SettingsPluginBase.__init__(self) LOG.debug("initialized") def convert_parameter(self, input_dict): LOG.debug("convert input parameter\n\n\t{}\n".format(pformat(input_dict))) solution_space = {} for name, content in input_dict.items(): particle = hyperopt_SettingsParticle(name=name) for key, value in content.items(): if key == 'domain': particle.domain = value elif key == 'data': particle.data = value elif key == 'type': particle.dtype = value solution_space[name] = particle.get() return solution_space class hyperopt_SettingsParticle(SettingsParticle): def __init__(self, name=None, domain=None, dtype=None, data=None): SettingsParticle.__init__(self, name, domain, dtype, data) def convert(self): if self.domain == "uniform": if self.dtype == "float" or self.dtype == "double": return hp.uniform(self.name, self.data[0], self.data[1]) elif self.dtype == "int": data = list(np.arange(int(self.data[0]), int(self.data[1]+1))) return hp.choice(self.name, data) else: msg = "cannot convert the type {} in domain {}".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) elif self.domain == "loguniform": if self.dtype == "float" or self.dtype == "double": return hp.loguniform(self.name, self.data[0], self.data[1]) else: msg = "cannot convert the type {} in domain {}".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) elif self.domain == "normal": if self.dtype == "float" or self.dtype == "double": return hp.normal(self.name, self.data[0], self.data[1]) else: msg = "cannot convert the type {} in domain {}".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) elif self.domain == "categorical": if self.dtype == 'str': return hp.choice(self.name, self.data) elif self.dtype == 'bool': data = [] for elem in self.data: if elem == "true" or elem == "True" or elem == 1 or elem == "1": data .append(True) elif elem == "false" or elem == "False" or elem == 0 or elem == "0": data .append(False) else: msg = "cannot convert the type {} in domain {}, unknown bool type value".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) return hp.choice(self.name, data) diff --git a/hyppopy/plugins/hyperopt_solver_plugin.py b/hyppopy/plugins/hyperopt_solver_plugin.py index 587bf07..4c5a5bc 100644 --- a/hyppopy/plugins/hyperopt_solver_plugin.py +++ b/hyppopy/plugins/hyperopt_solver_plugin.py @@ -1,80 +1,80 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat from hyperopt import fmin, tpe, hp, STATUS_OK, STATUS_FAIL, Trials from yapsy.IPlugin import IPlugin from hyppopy.projectmanager import ProjectManager from hyppopy.solverpluginbase import SolverPluginBase class hyperopt_Solver(SolverPluginBase, IPlugin): trials = None best = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") - def loss_function(self, params): + def blackbox_function(self, params): try: - loss = self.loss(self.data, params) + loss = self.blackbox_function_template(self.data, params) status = STATUS_OK except Exception as e: LOG.error("execution of self.loss(self.data, params) failed due to:\n {}".format(e)) status = STATUS_FAIL return {'loss': loss, 'status': status} def execute_solver(self, parameter): LOG.debug("execute_solver using solution space:\n\n\t{}\n".format(pformat(parameter))) self.trials = Trials() try: - self.best = fmin(fn=self.loss_function, + self.best = fmin(fn=self.blackbox_function, space=parameter, algo=tpe.suggest, max_evals=ProjectManager.max_iterations, trials=self.trials) except Exception as e: msg = "internal error in hyperopt.fmin occured. {}".format(e) LOG.error(msg) raise BrokenPipeError(msg) def convert_results(self): # currently converting results in a way that this function returns a dict # keeping all useful parameter as key/list item. This will be automatically # converted to a pandas dataframe in the solver class - results = {'timing ms': [], 'losses': []} + results = {'duration': [], 'losses': []} pset = self.trials.trials[0]['misc']['vals'] for p in pset.keys(): results[p] = [] for n, trial in enumerate(self.trials.trials): t1 = trial['book_time'] t2 = trial['refresh_time'] - results['timing ms'].append((t2 - t1).microseconds/1000.0) + results['duration'].append((t2 - t1).microseconds/1000.0) results['losses'].append(trial['result']['loss']) pset = trial['misc']['vals'] for p in pset.items(): results[p[0]].append(p[1][0]) return results, self.best diff --git a/hyppopy/plugins/optunity_solver_plugin.py b/hyppopy/plugins/optunity_solver_plugin.py index 895d3a5..1039874 100644 --- a/hyppopy/plugins/optunity_solver_plugin.py +++ b/hyppopy/plugins/optunity_solver_plugin.py @@ -1,64 +1,67 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat import optunity from yapsy.IPlugin import IPlugin from hyppopy.projectmanager import ProjectManager from hyppopy.solverpluginbase import SolverPluginBase class optunity_Solver(SolverPluginBase, IPlugin): solver_info = None trials = None best = None status = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") - def loss_function(self, **params): + def blackbox_function(self, **params): try: - loss = self.loss(self.data, params) + for key in params.keys(): + if self.settings.get_type_of(key) == 'int': + params[key] = int(round(params[key])) + loss = self.blackbox_function_template(self.data, params) self.status.append('ok') return loss except Exception as e: LOG.error("computing loss failed due to:\n {}".format(e)) self.status.append('fail') return 1e9 def execute_solver(self, parameter): LOG.debug("execute_solver using solution space:\n\n\t{}\n".format(pformat(parameter))) self.status = [] try: - self.best, self.trials, self.solver_info = optunity.minimize_structured(f=self.loss_function, + self.best, self.trials, self.solver_info = optunity.minimize_structured(f=self.blackbox_function, num_evals=ProjectManager.max_iterations, search_space=parameter) except Exception as e: LOG.error("internal error in optunity.minimize_structured occured. {}".format(e)) raise BrokenPipeError("internal error in optunity.minimize_structured occured. {}".format(e)) def convert_results(self): results = self.trials.call_log['args'] results['losses'] = self.trials.call_log['values'] return results, self.best diff --git a/hyppopy/resultviewer.py b/hyppopy/resultviewer.py new file mode 100644 index 0000000..096fe91 --- /dev/null +++ b/hyppopy/resultviewer.py @@ -0,0 +1,83 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt + +import logging +from hyppopy.globals import DEBUGLEVEL +LOG = logging.getLogger(os.path.basename(__file__)) +LOG.setLevel(DEBUGLEVEL) + +sns.set(style="darkgrid") + + +class ResultViewer(object): + + def __init__(self, fname=None, save_only=False): + self.df = None + self.has_duration = False + self.hyperparameter = None + self.save_only = save_only + self.path = None + self.appendix = None + if fname is not None: + self.read(fname) + + def read(self, fname): + self.path = os.path.dirname(fname) + split = os.path.basename(fname).split("_") + self.appendix = split[-2]+"_"+split[-1] + self.appendix = self.appendix[:-4] + self.df = pd.read_csv(fname, index_col=0) + const_data = ["duration", "losses"] + hyperparameter_columns = [item for item in self.df.columns if item not in const_data] + self.hyperparameter = pd.DataFrame() + for key in hyperparameter_columns: + self.hyperparameter[key] = self.df[key] + self.has_duration = "duration" in self.df.columns + + def show(self, save=True): + if self.has_duration: + sns_plot = sns.jointplot(y="duration", x="losses", data=self.df, kind="kde") + if not self.save_only: + plt.show() + if save: + save_name = os.path.join(self.path, "t_vs_loss_"+self.appendix+".png") + try: + sns_plot.savefig(save_name) + except Exception as e: + msg = "failed to save file {}, reason {}".format(save_name, e) + LOG.error(msg) + raise IOError(msg) + sns_plot = sns.pairplot(self.df, height=1.8, aspect=1.8, + plot_kws=dict(edgecolor="k", linewidth=0.5), + diag_kind="kde", diag_kws=dict(shade=True)) + + fig = sns_plot.fig + fig.subplots_adjust(top=0.93, wspace=0.3) + t = fig.suptitle('Pairwise Plots', fontsize=14) + if not self.save_only: + plt.show() + if save: + save_name = os.path.join(self.path, "matrixview_"+self.appendix+".png") + try: + sns_plot.savefig(save_name) + except Exception as e: + msg = "failed to save file {}, reason {}".format(save_name, e) + LOG.error(msg) + raise IOError(msg) diff --git a/hyppopy/settingsparticle.py b/hyppopy/settingsparticle.py index 60e9ce4..fc4c5cf 100644 --- a/hyppopy/settingsparticle.py +++ b/hyppopy/settingsparticle.py @@ -1,86 +1,90 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import abc import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class SettingsParticle(object): domains = ["uniform", "loguniform", "normal", "categorical"] _name = None _domain = None _dtype = None _data = None def __init__(self, name=None, domain=None, dtype=None, data=None): if name is not None: self.name = name if domain is not None: self.domain = domain if dtype is not None: self.dtype = dtype if data is not None: self.data = data @abc.abstractmethod def convert(self): raise NotImplementedError("the user has to implement this function") def get(self): msg = None if self.name is None: msg = "cannot convert unnamed parameter" if self.domain is None: msg = "cannot convert parameter of empty domain" if self.dtype is None: msg = "cannot convert parameter with unknown dtype" if self.data is None: msg = "cannot convert parameter having no data" if msg is not None: LOG.error(msg) raise LookupError(msg) return self.convert() @property def name(self): return self._name @name.setter def name(self, value): self._name = value @property def domain(self): return self._domain @domain.setter def domain(self, value): + if not value in self.domains: + msg = "domain named {} not available, check your domain name or implement new domain!".format(value) + LOG.error(msg) + raise LookupError(msg) self._domain = value @property def dtype(self): return self._dtype @dtype.setter def dtype(self, value): self._dtype = value @property def data(self): return self._data @data.setter def data(self, value): self._data = value diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py index 94cc169..8bb244e 100644 --- a/hyppopy/settingspluginbase.py +++ b/hyppopy/settingspluginbase.py @@ -1,76 +1,97 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import copy import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from hyppopy.deepdict import DeepDict class SettingsPluginBase(object): _data = None _name = None def __init__(self): self._data = {} @abc.abstractmethod def convert_parameter(self): raise NotImplementedError('users must define convert_parameter to use this base class') def get_hyperparameter(self): return self.convert_parameter(self.data) def set_hyperparameter(self, input_data): self.data.clear() self.data = copy.deepcopy(input_data) + def get_type_of(self, name): + if not name in self.data: + msg = "hyperparameter named {} not found!".format(name) + LOG.error(msg) + raise LookupError(msg) + return self.data[name]["type"] + + def get_domain_of(self, name): + if not name in self.data: + msg = "hyperparameter named {} not found!".format(name) + LOG.error(msg) + raise LookupError(msg) + return self.data[name]["domain"] + + def get_data_of(self, name): + if not name in self.data: + msg = "hyperparameter named {} not found!".format(name) + LOG.error(msg) + raise LookupError(msg) + return self.data[name]["data"] + def read(self, fname): self.data.clear() self.data.from_file(fname) def write(self, fname): self.data.to_file(fname) @property def data(self): return self._data @data.setter def data(self, value): if isinstance(value, dict): self._data = value elif isinstance(value, DeepDict): self._data = value.data else: raise IOError("unexpected input type({}) for data, needs to be of type dict or DeepDict!".format(type(value))) @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error("Invalid input, str type expected for value, got {} instead".format(type(value))) raise IOError("Invalid input, str type expected for value, got {} instead".format(type(value))) self._name = value diff --git a/hyppopy/solver.py b/hyppopy/solver.py index c9bbff5..4d4a183 100644 --- a/hyppopy/solver.py +++ b/hyppopy/solver.py @@ -1,117 +1,126 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.projectmanager import ProjectManager +from hyppopy.resultviewer import ResultViewer import os import datetime import logging import pandas as pd +from hyppopy.globals import LIBNAME from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Solver(object): _name = None _solver_plugin = None _settings_plugin = None def __init__(self): pass def set_data(self, data): - self.solver.set_data(data) + self._solver_plugin.set_data(data) def set_hyperparameters(self, params): - self.settings.set_hyperparameter(params) + self.settings_plugin.set_hyperparameter(params) - def set_loss_function(self, loss_func): - self.solver.set_loss_function(loss_func) + def set_loss_function(self, func): + self._solver_plugin.set_blackbox_function(func) def run(self): if not ProjectManager.is_ready(): LOG.error("No config data found to initialize PluginSetting object") raise IOError("No config data found to initialize PluginSetting object") - self.settings.set_hyperparameter(ProjectManager.get_hyperparameter()) - self.solver.settings = self.settings - self.solver.run() + self.settings_plugin.set_hyperparameter(ProjectManager.get_hyperparameter()) + self._solver_plugin.settings = self.settings_plugin + self._solver_plugin.run() - def save_results(self, savedir=None, savename=None): + def save_results(self, savedir=None, savename=None, show=False): df, best = self.get_results() dir = None if savename is None: - savename = "hypopy" + savename = LIBNAME if savedir is None: if 'output_dir' in ProjectManager.__dict__.keys(): if not os.path.isdir(ProjectManager.output_dir): os.mkdir(ProjectManager.output_dir) dir = ProjectManager.output_dir else: print("WARNING: No solver option output_dir found, cannot save results!") LOG.warning("WARNING: No solver option output_dir found, cannot save results!") else: dir = savedir if not os.path.isdir(savedir): os.mkdir(savedir) tstr = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") name = savename + "_all_" + tstr + ".csv" - fname = os.path.join(dir, name) - df.to_csv(fname) + fname_all = os.path.join(dir, name) + df.to_csv(fname_all) name = savename + "_best_" + tstr + ".txt" - fname = os.path.join(dir, name) - with open(fname, "w") as text_file: + fname_best = os.path.join(dir, name) + with open(fname_best, "w") as text_file: for item in best.items(): text_file.write("{}\t:\t{}\n".format(item[0], item[1])) + if show: + viewer = ResultViewer(fname_all) + viewer.show() + else: + viewer = ResultViewer(fname_all, save_only=True) + viewer.show() + def get_results(self): - results, best = self.solver.get_results() + results, best = self._solver_plugin.get_results() df = pd.DataFrame.from_dict(results) return df, best @property def is_ready(self): - return self.solver is not None and self.settings is not None + return self._solver_plugin is not None and self.settings_plugin is not None @property - def solver(self): + def solver_plugin(self): return self._solver_plugin - @solver.setter - def solver(self, value): + @solver_plugin.setter + def solver_plugin(self, value): self._solver_plugin = value @property - def settings(self): + def settings_plugin(self): return self._settings_plugin - @settings.setter - def settings(self, value): + @settings_plugin.setter + def settings_plugin(self, value): self._settings_plugin = value @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): msg = "Invalid input, str type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._name = value diff --git a/hyppopy/solverfactory.py b/hyppopy/solverfactory.py index bcefca1..fb053b6 100644 --- a/hyppopy/solverfactory.py +++ b/hyppopy/solverfactory.py @@ -1,165 +1,165 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from yapsy.PluginManager import PluginManager from hyppopy.projectmanager import ProjectManager from hyppopy.globals import PLUGIN_DEFAULT_DIR from hyppopy.deepdict import DeepDict from hyppopy.solver import Solver from hyppopy.singleton import * import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) @singleton_object class SolverFactory(metaclass=Singleton): """ This class is responsible for grabbing all plugins from the plugin folder arranging them into a Solver class instances. These Solver class instances can be requested from the factory via the get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using SolverFactory(), the consequences will be horrific. Instead use is like a class having static functions only, SolverFactory.method(). """ _plugin_dirs = [] _plugins = {} def __init__(self): self.reset() self.load_plugins() LOG.debug("Solverfactory initialized") def load_plugins(self): """ Load plugin modules from plugin paths """ LOG.debug("load_plugins()") manager = PluginManager() LOG.debug("setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) manager.setPluginPlaces(self._plugin_dirs) manager.collectPlugins() for plugin in manager.getAllPlugins(): name_elements = plugin.plugin_object.__class__.__name__.split("_") LOG.debug("found plugin " + " ".join(map(str, name_elements))) print("Solverfactory: found plugins " + " ".join(map(str, name_elements))) if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements): msg = "invalid plugin class naming for class {}, the convention is libname_Solver or libname_Settings.".format(plugin.plugin_object.__class__.__name__) LOG.error(msg) raise NameError(msg) if name_elements[0] not in self._plugins.keys(): self._plugins[name_elements[0]] = Solver() self._plugins[name_elements[0]].name = name_elements[0] if name_elements[1] == "Solver": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] - self._plugins[name_elements[0]].solver = obj + self._plugins[name_elements[0]].solver_plugin = obj LOG.info("plugin: {} Solver loaded".format(name_elements[0])) except Exception as e: msg = "failed to instanciate class {}".format(plugin.plugin_object.__class__.__name__) LOG.error(msg) raise ImportError(msg) elif name_elements[1] == "Settings": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] - self._plugins[name_elements[0]].settings = obj + self._plugins[name_elements[0]].settings_plugin = obj LOG.info("plugin: {} ParameterSpace loaded".format(name_elements[0])) except Exception as e: msg = "failed to instanciate class {}".format(plugin.plugin_object.__class__.__name__) LOG.error(msg) raise ImportError(msg) else: msg = "failed loading plugin {}, please check if naming conventions are kept!".format(name_elements[0]) LOG.error(msg) raise IOError(msg) if len(self._plugins) == 0: msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!" LOG.error(msg) raise IOError(msg) def reset(self): """ Reset solver factory """ LOG.debug("reset()") self._plugins = {} self._plugin_dirs = [] self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR)) def add_plugin_dir(self, dir): """ Add plugin directory """ LOG.debug("add_plugin_dir({})".format(dir)) self._plugin_dirs.append(dir) def list_solver(self): """ list all solvers available :return: [list(str)] """ return list(self._plugins.keys()) def from_settings(self, settings): if isinstance(settings, str): if not os.path.isfile(settings): LOG.error("input error, file {} not found!".format(settings)) if not ProjectManager.read_config(settings): LOG.error("failed to read config in ProjectManager!") return None else: if not ProjectManager.set_config(settings): LOG.error("failed to set config in ProjectManager!") return None if not ProjectManager.is_ready(): LOG.error("failed to set config in ProjectManager!") return None try: solver = self.get_solver(ProjectManager.use_plugin) except Exception as e: msg = "failed to create solver, reason {}".format(e) LOG.error(msg) return None return solver def get_solver(self, name=None): """ returns a solver by name tag :param name: [str] solver name :return: [Solver] instance """ if name is None: try: name = ProjectManager.use_plugin except Exception as e: msg = "failed to setup solver, no solver specified, check your ProjectManager for the use_plugin value!" LOG.error(msg) raise LookupError(msg) if not isinstance(name, str): msg = "Invalid input, str type expected for name, got {} instead".format(type(name)) LOG.error(msg) raise IOError(msg) if name not in self.list_solver(): msg = "failed solver request, a solver called {} is not available, check for typo or if your plugin failed while loading!".format(name) LOG.error(msg) raise LookupError(msg) LOG.debug("get_solver({})".format(name)) return self._plugins[name] diff --git a/hyppopy/solverpluginbase.py b/hyppopy/solverpluginbase.py index 97485a5..7e3ddc4 100644 --- a/hyppopy/solverpluginbase.py +++ b/hyppopy/solverpluginbase.py @@ -1,93 +1,92 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os -import time import logging from hyppopy.globals import DEBUGLEVEL from hyppopy.settingspluginbase import SettingsPluginBase LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class SolverPluginBase(object): _data = None - _loss = None + _blackbox_function_template = None _settings = None _name = None def __init__(self): pass @abc.abstractmethod - def loss_function(self, params): + def blackbox_function(self, params): raise NotImplementedError('users must define loss_func to use this base class') @abc.abstractmethod def execute_solver(self): raise NotImplementedError('users must define execute_solver to use this base class') @abc.abstractmethod def convert_results(self): raise NotImplementedError('users must define convert_results to use this base class') def set_data(self, data): self._data = data - def set_loss_function(self, func): - self._loss = func + def set_blackbox_function(self, func): + self._blackbox_function_template = func def get_results(self): return self.convert_results() def run(self): self.execute_solver(self.settings.get_hyperparameter()) @property def data(self): return self._data @property - def loss(self): - return self._loss + def blackbox_function_template(self): + return self._blackbox_function_template @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): msg = "Invalid input, str type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._name = value @property def settings(self): return self._settings @settings.setter def settings(self, value): if not isinstance(value, SettingsPluginBase): msg = "Invalid input, SettingsPluginBase type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._settings = value diff --git a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py index 632e83a..392b9b7 100644 --- a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py +++ b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py @@ -1,38 +1,36 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import cross_val_score from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.workflowbase import WorkflowBase from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoader class randomforest_usecase(WorkflowBase): def setup(self, **kwargs): dl = SimpleDataLoader() dl.start(path=ProjectManager.data_path, data_name=ProjectManager.data_name, labels_name=ProjectManager.labels_name) self.solver.set_data(dl.data) def blackbox_function(self, data, params): - if "n_estimators" in params.keys(): - params["n_estimators"] = int(round(params["n_estimators"])) clf = RandomForestClassifier(**params) return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() diff --git a/requirements.txt b/requirements.txt index aeccb6e..7688b4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,16 @@ dicttoxml==1.7.4 hyperopt==0.1.1 matplotlib==3.0.2 numpy==1.16.0 Optunity==1.1.1 pytest==4.1.1 scikit-learn==0.20.2 scipy==1.2.0 sklearn==0.0 Sphinx==1.8.3 xmlrunner==1.7.7 xmltodict==0.11.0 Yapsy==1.11.223 pandas==0.24.1 -trixi==0.1.1.6 -torch==1.0.0 medpy==0.3.0 batchgenerators==0.18.1 \ No newline at end of file