diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py index a925996..bff67e6 100644 --- a/hyppopy/settingspluginbase.py +++ b/hyppopy/settingspluginbase.py @@ -1,78 +1,84 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from hyppopy.globals import SETTINGSPATH from hyppopy.deepdict import DeepDict class SettingsPluginBase(object): _data = None _name = None def __init__(self): self._data = DeepDict() @abc.abstractmethod def convert_parameter(self): raise NotImplementedError('users must define convert_parameter to use this base class') def get_hyperparameter(self): return self.convert_parameter(self.data["hyperparameter"]) def set(self, data): self.data.clear() - self.data.data = data + self.data = data def read(self, fname): self.data.clear() self.data.from_file(fname) def write(self, fname): self.data.to_file(fname) def set_attributes(self, cls): attrs_sec = self.data[SETTINGSPATH] for key, value in attrs_sec.items(): setattr(cls, key, value) @property def data(self): return self._data @data.setter def data(self, value): - return self._data + if isinstance(value, dict): + self._data.data = value + elif isinstance(value, DeepDict): + self._data = value + else: + raise IOError(f"unexpected input type({type(value)}) for data, needs to be of type dict or DeepDict!") + @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solverfactory.py b/hyppopy/solverfactory.py index 1de0713..9b39f38 100644 --- a/hyppopy/solverfactory.py +++ b/hyppopy/solverfactory.py @@ -1,153 +1,180 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from yapsy.PluginManager import PluginManager from hyppopy.globals import PLUGIN_DEFAULT_DIR +from hyppopy.deepdict import DeepDict from hyppopy.solver import Solver import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class SolverFactory(object): """ This class is responsible for grabbing all plugins from the plugin folder arranging them into a Solver class instances. These Solver class instances can be requested from the factory via the get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using SolverFactory(), the consequences will be horrific. Instead use factory = SolverFactory.instance(). """ _instance = None _locked = True _plugin_dirs = [] _plugins = {} def __init__(self): if self._locked: msg = "!!! seems you used SolverFactory() to get an instance, please don't do that, "\ "it will kill a cute puppy anywhere close to you! SolverFactory is a "\ "Singleton, means please use SolverFactory.instance() instead !!!" LOG.error(msg) raise AssertionError(msg) if SolverFactory._instance is not None: pass else: SolverFactory._instance = self self.reset() self.load_plugins() LOG.debug("initialized") @staticmethod def instance(): """ Singleton instance access :return: [SolverFactory] instance """ SolverFactory._locked = False LOG.debug("instance request") if SolverFactory._instance is None: SolverFactory() SolverFactory._locked = True return SolverFactory._instance def load_plugins(self): """ Load plugin modules from plugin paths """ LOG.debug("load_plugins()") manager = PluginManager() LOG.debug(f"setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) manager.setPluginPlaces(self._plugin_dirs) manager.collectPlugins() for plugin in manager.getAllPlugins(): name_elements = plugin.plugin_object.__class__.__name__.split("_") LOG.debug("found plugin " + " ".join(map(str, name_elements))) if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements): LOG.error(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") raise NameError(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") if name_elements[0] not in self._plugins.keys(): self._plugins[name_elements[0]] = Solver() self._plugins[name_elements[0]].name = name_elements[0] if name_elements[1] == "Solver": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].solver = obj LOG.info(f"plugin: {name_elements[0]} Solver loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"Failed to instanciate class {plugin.plugin_object.__class__.__name__}") elif name_elements[1] == "Settings": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].settings = obj LOG.info(f"plugin: {name_elements[0]} ParameterSpace loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") else: LOG.error(f"failed loading plugin {name_elements[0]}, please check if naming conventions are kept!") raise IOError(f"failed loading plugin {name_elements[0]}!, please check if naming conventions are kept!") if len(self._plugins) == 0: msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!" LOG.error(msg) raise IOError(msg) def reset(self): """ Reset solver factory """ LOG.debug("reset()") self._plugins = {} self._plugin_dirs = [] self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR)) def add_plugin_dir(self, dir): """ Add plugin directory """ LOG.debug(f"add_plugin_dir({dir})") self._plugin_dirs.append(dir) def list_solver(self): """ list all solvers available :return: [list(str)] """ return list(self._plugins.keys()) + def from_settings(self, settings): + if isinstance(settings, dict): + tmp = DeepDict() + tmp.data = settings + settings = tmp + elif isinstance(settings, str): + if not os.path.isfile(settings): + LOG.warning(f"input error, file {settings} not found!") + settings = DeepDict(settings) + + if isinstance(settings, DeepDict): + if settings.has_section("use_plugin"): + try: + use_plugin = settings["settings/solver/use_plugin"] + except Exception as e: + LOG.warning("wrong settings path for use_plugin option detected, expecting the path settings/solver/use_plugin!") + solver = self.get_solver(use_plugin) + solver.set_parameters(settings) + return solver + LOG.warning("failed to choose a solver, either the config file is missing the section settings/solver/use_plugin, or there might be a typo") + else: + msg = "unknown input error, expected DeepDict, dict or filename!" + LOG.error(msg) + raise IOError(msg) + return None + def get_solver(self, name): """ returns a solver by name tag :param name: [str] solver name :return: [Solver] instance """ if not isinstance(name, str): msg = f"Invalid input, str type expected for name, got {type(name)} instead" LOG.error(msg) raise IOError(msg) if name not in self.list_solver(): msg = f"failed solver request, a solver called {name} is not available, " \ f"check for typo or if your plugin failed while loading!" LOG.error(msg) raise LookupError(msg) LOG.debug(f"get_solver({name})") return self._plugins[name] diff --git a/hyppopy/tests/data/iris_svc_parameter.json b/hyppopy/tests/data/iris_svc_parameter.json index 26572b7..45e2ff2 100644 --- a/hyppopy/tests/data/iris_svc_parameter.json +++ b/hyppopy/tests/data/iris_svc_parameter.json @@ -1,22 +1,23 @@ {"hyperparameter": { "C": { "domain": "uniform", "data": "[0,20]", "type": "float" }, "gamma": { "domain": "uniform", "data": "[0.0001,20.0]", "type": "float" }, "kernel": { "domain": "categorical", "data": "[linear, sigmoid, poly, rbf]", "type": "str" } }, "settings": { "solver": { - "max_iterations": "50" + "max_iterations": "50", + "use_plugin" : "hyperopt" } }} \ No newline at end of file diff --git a/hyppopy/tests/data/iris_svc_parameter.xml b/hyppopy/tests/data/iris_svc_parameter.xml index 1d3217c..7d9670d 100644 --- a/hyppopy/tests/data/iris_svc_parameter.xml +++ b/hyppopy/tests/data/iris_svc_parameter.xml @@ -1,24 +1,25 @@ uniform [0,20] float uniform [0.0001,20.0] float categorical [linear,sigmoid,poly,rbf] str 50 + optunity \ No newline at end of file diff --git a/hyppopy/tests/test_solver_factory.py b/hyppopy/tests/test_solver_factory.py index c5dbe2c..0d87856 100644 --- a/hyppopy/tests/test_solver_factory.py +++ b/hyppopy/tests/test_solver_factory.py @@ -1,73 +1,109 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import unittest from sklearn.svm import SVC from sklearn import datasets from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split from hyppopy.solverfactory import SolverFactory from hyppopy.globals import TESTDATA_DIR -TESTPARAMFILE = os.path.join(TESTDATA_DIR, 'iris_svc_parameter.xml') +TESTPARAMFILE = os.path.join(TESTDATA_DIR, 'iris_svc_parameter') -from hyppopy.deepdict.deepdict import DeepDict +from hyppopy.deepdict import DeepDict class SolverFactoryTestSuite(unittest.TestCase): def setUp(self): - pass + iris = datasets.load_iris() + X, X_test, y, y_test = train_test_split(iris.data, iris.target, test_size=0.1, random_state=42) + self.my_IRIS_dta = [X, y] def test_solver_loading(self): factory = SolverFactory.instance() names = factory.list_solver() self.assertTrue("hyperopt" in names) self.assertTrue("optunity" in names) def test_iris_solver_execution(self): - iris = datasets.load_iris() - X, X_test, y, y_test = train_test_split(iris.data, iris.target, test_size=0.1, random_state=42) - my_IRIS_dta = [X, y] + def my_SVC_loss_func(data, params): clf = SVC(**params) return -cross_val_score(clf, data[0], data[1], cv=3).mean() factory = SolverFactory.instance() solver = factory.get_solver('optunity') - solver.set_data(my_IRIS_dta) - solver.read_parameter(TESTPARAMFILE) + solver.set_data(self.my_IRIS_dta) + solver.read_parameter(TESTPARAMFILE + '.xml') solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() solver = factory.get_solver('hyperopt') - solver.set_data(my_IRIS_dta) - solver.read_parameter(TESTPARAMFILE) + solver.set_data(self.my_IRIS_dta) + solver.read_parameter(TESTPARAMFILE + '.json') + solver.set_loss_function(my_SVC_loss_func) + solver.run() + solver.get_results() + + def test_create_solver_from_settings_directly(self): + factory = SolverFactory.instance() + + def my_SVC_loss_func(data, params): + clf = SVC(**params) + return -cross_val_score(clf, data[0], data[1], cv=3).mean() + + solver = factory.from_settings(TESTPARAMFILE + '.xml') + self.assertEqual(solver.name, "optunity") + solver.set_data(self.my_IRIS_dta) + solver.set_loss_function(my_SVC_loss_func) + solver.run() + solver.get_results() + + solver = factory.from_settings(TESTPARAMFILE + '.json') + self.assertEqual(solver.name, "hyperopt") + solver.set_data(self.my_IRIS_dta) + solver.set_loss_function(my_SVC_loss_func) + solver.run() + solver.get_results() + + dd = DeepDict(TESTPARAMFILE + '.json') + solver = factory.from_settings(dd) + self.assertEqual(solver.name, "hyperopt") + solver.set_data(self.my_IRIS_dta) + solver.set_loss_function(my_SVC_loss_func) + solver.run() + solver.get_results() + + solver = factory.from_settings(dd.data) + self.assertEqual(solver.name, "hyperopt") + solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() if __name__ == '__main__': unittest.main()