diff --git a/examples/use_hyppopy_solver.py b/examples/use_hyppopy_solver.py new file mode 100644 index 0000000..5cc7906 --- /dev/null +++ b/examples/use_hyppopy_solver.py @@ -0,0 +1,47 @@ +import os +import sys +from sklearn.svm import SVC +from sklearn.model_selection import cross_val_score + +# the ProjectManager is loading your config file and giving you access +# to everything specified in the settings/custom section of the config +from hyppopy.projectmanager import ProjectManager + +# the SolverFactory builds the Solver class for you +from hyppopy.solverfactory import SolverFactory + +# we use in this example the SimpleDataLoader +from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoader + +# until Hyppopy is not fully installable we need +# to set the Hyppopy package folder by hand +HYPPOPY_DIR = "D:/Projects/Python/hyppopy" +sys.path.append(HYPPOPY_DIR) + +# let the ProjectManager read your config file +DATA = os.path.join(HYPPOPY_DIR, *("hyppopy", "tests", "data", "Titanic")) +ProjectManager.read_config(os.path.join(DATA, 'rf_config.json')) + +# ----- reading data ------ +dl = SimpleDataLoader() +dl.start(path=ProjectManager.data_path, + data_name=ProjectManager.data_name, + labels_name=ProjectManager.labels_name) +# ------------------------- + +# ----- defining loss function ------ +def blackbox_function(params): + clf = SVC(**params) + return -cross_val_score(estimator=clf, X=dl.data[0], y=dl.data[1], cv=3).mean() +# ----------------------------------- + +# ----- create and run the solver ------ +# get a solver instance from the SolverFactory +solver = SolverFactory.get_solver() +# set your loss function +solver.set_loss_function(blackbox_function) +# run the solver +solver.run() +# store your results +solver.save_results(savedir="C:\\Users\\s635r\\Desktop\\myTestProject") +# -------------------------------------- \ No newline at end of file diff --git a/hyppopy/solver.py b/hyppopy/solver.py index 2ceb972..c9bbff5 100644 --- a/hyppopy/solver.py +++ b/hyppopy/solver.py @@ -1,118 +1,117 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.projectmanager import ProjectManager import os import datetime import logging import pandas as pd from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Solver(object): _name = None _solver_plugin = None _settings_plugin = None def __init__(self): pass def set_data(self, data): self.solver.set_data(data) def set_hyperparameters(self, params): self.settings.set_hyperparameter(params) def set_loss_function(self, loss_func): self.solver.set_loss_function(loss_func) def run(self): if not ProjectManager.is_ready(): LOG.error("No config data found to initialize PluginSetting object") raise IOError("No config data found to initialize PluginSetting object") - hyps = ProjectManager.get_hyperparameter() - self.settings.set_hyperparameter(hyps) + self.settings.set_hyperparameter(ProjectManager.get_hyperparameter()) self.solver.settings = self.settings self.solver.run() def save_results(self, savedir=None, savename=None): df, best = self.get_results() dir = None if savename is None: savename = "hypopy" if savedir is None: if 'output_dir' in ProjectManager.__dict__.keys(): if not os.path.isdir(ProjectManager.output_dir): os.mkdir(ProjectManager.output_dir) dir = ProjectManager.output_dir else: print("WARNING: No solver option output_dir found, cannot save results!") LOG.warning("WARNING: No solver option output_dir found, cannot save results!") else: + dir = savedir if not os.path.isdir(savedir): os.mkdir(savedir) - dir = savedir tstr = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") name = savename + "_all_" + tstr + ".csv" fname = os.path.join(dir, name) df.to_csv(fname) name = savename + "_best_" + tstr + ".txt" fname = os.path.join(dir, name) with open(fname, "w") as text_file: for item in best.items(): text_file.write("{}\t:\t{}\n".format(item[0], item[1])) def get_results(self): results, best = self.solver.get_results() df = pd.DataFrame.from_dict(results) return df, best @property def is_ready(self): return self.solver is not None and self.settings is not None @property def solver(self): return self._solver_plugin @solver.setter def solver(self, value): self._solver_plugin = value @property def settings(self): return self._settings_plugin @settings.setter def settings(self, value): self._settings_plugin = value @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): msg = "Invalid input, str type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._name = value diff --git a/hyppopy/solverfactory.py b/hyppopy/solverfactory.py index d5e3246..bcefca1 100644 --- a/hyppopy/solverfactory.py +++ b/hyppopy/solverfactory.py @@ -1,158 +1,165 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from yapsy.PluginManager import PluginManager from hyppopy.projectmanager import ProjectManager from hyppopy.globals import PLUGIN_DEFAULT_DIR from hyppopy.deepdict import DeepDict from hyppopy.solver import Solver from hyppopy.singleton import * import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) @singleton_object class SolverFactory(metaclass=Singleton): """ This class is responsible for grabbing all plugins from the plugin folder arranging them into a Solver class instances. These Solver class instances can be requested from the factory via the get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using SolverFactory(), the consequences will be horrific. Instead use is like a class having static functions only, SolverFactory.method(). """ _plugin_dirs = [] _plugins = {} def __init__(self): self.reset() self.load_plugins() LOG.debug("Solverfactory initialized") def load_plugins(self): """ Load plugin modules from plugin paths """ LOG.debug("load_plugins()") manager = PluginManager() LOG.debug("setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) manager.setPluginPlaces(self._plugin_dirs) manager.collectPlugins() for plugin in manager.getAllPlugins(): name_elements = plugin.plugin_object.__class__.__name__.split("_") LOG.debug("found plugin " + " ".join(map(str, name_elements))) print("Solverfactory: found plugins " + " ".join(map(str, name_elements))) if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements): msg = "invalid plugin class naming for class {}, the convention is libname_Solver or libname_Settings.".format(plugin.plugin_object.__class__.__name__) LOG.error(msg) raise NameError(msg) if name_elements[0] not in self._plugins.keys(): self._plugins[name_elements[0]] = Solver() self._plugins[name_elements[0]].name = name_elements[0] if name_elements[1] == "Solver": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].solver = obj LOG.info("plugin: {} Solver loaded".format(name_elements[0])) except Exception as e: msg = "failed to instanciate class {}".format(plugin.plugin_object.__class__.__name__) LOG.error(msg) raise ImportError(msg) elif name_elements[1] == "Settings": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].settings = obj LOG.info("plugin: {} ParameterSpace loaded".format(name_elements[0])) except Exception as e: msg = "failed to instanciate class {}".format(plugin.plugin_object.__class__.__name__) LOG.error(msg) raise ImportError(msg) else: msg = "failed loading plugin {}, please check if naming conventions are kept!".format(name_elements[0]) LOG.error(msg) raise IOError(msg) if len(self._plugins) == 0: msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!" LOG.error(msg) raise IOError(msg) def reset(self): """ Reset solver factory """ LOG.debug("reset()") self._plugins = {} self._plugin_dirs = [] self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR)) def add_plugin_dir(self, dir): """ Add plugin directory """ LOG.debug("add_plugin_dir({})".format(dir)) self._plugin_dirs.append(dir) def list_solver(self): """ list all solvers available :return: [list(str)] """ return list(self._plugins.keys()) def from_settings(self, settings): if isinstance(settings, str): if not os.path.isfile(settings): LOG.error("input error, file {} not found!".format(settings)) if not ProjectManager.read_config(settings): LOG.error("failed to read config in ProjectManager!") return None else: if not ProjectManager.set_config(settings): LOG.error("failed to set config in ProjectManager!") return None if not ProjectManager.is_ready(): LOG.error("failed to set config in ProjectManager!") return None try: solver = self.get_solver(ProjectManager.use_plugin) except Exception as e: msg = "failed to create solver, reason {}".format(e) LOG.error(msg) return None return solver - def get_solver(self, name): + def get_solver(self, name=None): """ returns a solver by name tag :param name: [str] solver name :return: [Solver] instance """ + if name is None: + try: + name = ProjectManager.use_plugin + except Exception as e: + msg = "failed to setup solver, no solver specified, check your ProjectManager for the use_plugin value!" + LOG.error(msg) + raise LookupError(msg) if not isinstance(name, str): msg = "Invalid input, str type expected for name, got {} instead".format(type(name)) LOG.error(msg) raise IOError(msg) if name not in self.list_solver(): msg = "failed solver request, a solver called {} is not available, check for typo or if your plugin failed while loading!".format(name) LOG.error(msg) raise LookupError(msg) LOG.debug("get_solver({})".format(name)) return self._plugins[name] diff --git a/hyppopy/solverpluginbase.py b/hyppopy/solverpluginbase.py index 795f5d3..a58dd94 100644 --- a/hyppopy/solverpluginbase.py +++ b/hyppopy/solverpluginbase.py @@ -1,90 +1,85 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import time import logging from hyppopy.globals import DEBUGLEVEL from hyppopy.settingspluginbase import SettingsPluginBase LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class SolverPluginBase(object): data = None loss = None _settings = None _name = None - _timer = [] def __init__(self): pass @abc.abstractmethod def loss_function(self, params): raise NotImplementedError('users must define loss_func to use this base class') @abc.abstractmethod def execute_solver(self): raise NotImplementedError('users must define execute_solver to use this base class') @abc.abstractmethod def convert_results(self): raise NotImplementedError('users must define convert_results to use this base class') def set_data(self, data): self.data = data - self._timer = [] def set_loss_function(self, func): self.loss = func def get_results(self): return self.convert_results() def run(self): - start = time.time() self.execute_solver(self.settings.get_hyperparameter()) - end = time.time() - self._timer.append(end - start) @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): msg = "Invalid input, str type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._name = value @property def settings(self): return self._settings @settings.setter def settings(self, value): if not isinstance(value, SettingsPluginBase): msg = "Invalid input, SettingsPluginBase type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._settings = value diff --git a/hyppopy/workflows/dataloader/simpleloader.py b/hyppopy/workflows/dataloader/simpleloader.py index 6760cfd..c2fab98 100644 --- a/hyppopy/workflows/dataloader/simpleloader.py +++ b/hyppopy/workflows/dataloader/simpleloader.py @@ -1,41 +1,41 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import numpy as np import pandas as pd from hyppopy.workflows.dataloader.dataloaderbase import DataLoaderBase -class SimpleDataLoaderBase(DataLoaderBase): +class SimpleDataLoader(DataLoaderBase): def read(self, **kwargs): if kwargs['data_name'].endswith(".npy"): if not kwargs['labels_name'].endswith(".npy"): raise IOError("Expect both data_name and labels_name being of type .npy!") self.data = [np.load(os.path.join(kwargs['path'], kwargs['data_name'])), np.load(os.path.join(kwargs['path'], kwargs['labels_name']))] elif kwargs['data_name'].endswith(".csv"): try: dataset = pd.read_csv(os.path.join(kwargs['path'], kwargs['data_name'])) y = dataset[kwargs['labels_name']].values X = dataset.drop([kwargs['labels_name']], axis=1).values self.data = [X, y] except Exception as e: print("Precondition violation, this usage case expects as data_name a " "csv file and as label_name a name of a column in this csv table!") else: raise NotImplementedError("This combination of data_name and labels_name " "does not yet exist, feel free to add it") diff --git a/hyppopy/workflows/imageregistration_usecase/imageregistration_usecase.py b/hyppopy/workflows/imageregistration_usecase/imageregistration_usecase.py index 993c72a..652393a 100644 --- a/hyppopy/workflows/imageregistration_usecase/imageregistration_usecase.py +++ b/hyppopy/workflows/imageregistration_usecase/imageregistration_usecase.py @@ -1,52 +1,52 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: #------------------------------------------------------ # this needs to be imported, dont remove these from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.workflowbase import WorkflowBase #------------------------------------------------------ # import your external packages from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import cross_val_score # import your custom DataLoader from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase # This is a dataloader class create your own class imageregistration_usecase(WorkflowBase): - def setup(self): + def setup(self, **kwargs): # here you create your own DataLoader instance dl = SimpleDataLoaderBase() # call the start function of your DataLoader dl.start(path=ProjectManager.data_path, data_name=ProjectManager.data_name, labels_name=ProjectManager.labels_name) # pass the data to the solver self.solver.set_data(dl.data) def blackbox_function(self, data, params): # converting number back to integers is an ugly hack that will be removed in the future if "n_estimators" in params.keys(): params["n_estimators"] = int(round(params["n_estimators"])) # Do your training clf = RandomForestClassifier(**params) # compute your loss loss = -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() # return loss return loss diff --git a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py index 5c14b6e..632e83a 100644 --- a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py +++ b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py @@ -1,38 +1,38 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import cross_val_score from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.workflowbase import WorkflowBase -from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase +from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoader class randomforest_usecase(WorkflowBase): - def setup(self): - dl = SimpleDataLoaderBase() + def setup(self, **kwargs): + dl = SimpleDataLoader() dl.start(path=ProjectManager.data_path, data_name=ProjectManager.data_name, labels_name=ProjectManager.labels_name) self.solver.set_data(dl.data) def blackbox_function(self, data, params): if "n_estimators" in params.keys(): params["n_estimators"] = int(round(params["n_estimators"])) clf = RandomForestClassifier(**params) return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() diff --git a/hyppopy/workflows/svc_usecase/svc_usecase.py b/hyppopy/workflows/svc_usecase/svc_usecase.py index 4108969..9f7800a 100644 --- a/hyppopy/workflows/svc_usecase/svc_usecase.py +++ b/hyppopy/workflows/svc_usecase/svc_usecase.py @@ -1,38 +1,38 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import numpy as np import pandas as pd from sklearn.svm import SVC from sklearn.model_selection import cross_val_score from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.workflowbase import WorkflowBase -from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase +from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoader class svc_usecase(WorkflowBase): - def setup(self): - dl = SimpleDataLoaderBase() + def setup(self, **kwargs): + dl = SimpleDataLoader() dl.start(path=ProjectManager.data_path, data_name=ProjectManager.data_name, labels_name=ProjectManager.labels_name) self.solver.set_data(dl.data) def blackbox_function(self, data, params): clf = SVC(**params) return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() diff --git a/hyppopy/workflows/workflowbase.py b/hyppopy/workflows/workflowbase.py index 3b8a8a5..6b14ada 100644 --- a/hyppopy/workflows/workflowbase.py +++ b/hyppopy/workflows/workflowbase.py @@ -1,63 +1,62 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.deepdict import DeepDict from hyppopy.solverfactory import SolverFactory from hyppopy.projectmanager import ProjectManager from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH import os import abc import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class WorkflowBase(object): def __init__(self): - self._solver = SolverFactory.get_solver(ProjectManager.use_plugin) - self.solver.set_hyperparameters(ProjectManager.get_hyperparameter()) + self._solver = SolverFactory.get_solver() def run(self, save=True): self.setup() self.solver.set_loss_function(self.blackbox_function) self.solver.run() if save: self.solver.save_results() self.test() def get_results(self): return self.solver.get_results() @abc.abstractmethod - def setup(self): + def setup(self, **kwargs): raise NotImplementedError('the user has to implement this function') @abc.abstractmethod def blackbox_function(self): raise NotImplementedError('the user has to implement this function') @abc.abstractmethod def test(self): pass @property def solver(self): return self._solver