diff --git a/hyppopy/globals.py b/hyppopy/globals.py index fd6f33f..b06a233 100644 --- a/hyppopy/globals.py +++ b/hyppopy/globals.py @@ -1,32 +1,32 @@ # DKFZ # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # -*- coding: utf-8 -*- import os import sys import logging ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT) PLUGIN_DEFAULT_DIR = os.path.join(ROOT, *("hyppopy", "plugins")) TESTDATA_DIR = os.path.join(ROOT, *("hyppopy", "tests", "data")) -SETTINGSPATH = "settings/solver" -CUSTOMPATH = "settings/custom" +SETTINGSSOLVERPATH = "settings/solver" +SETTINGSCUSTOMPATH = "settings/custom" DEEPDICT_XML_ROOT = "hyppopy" LOGFILENAME = os.path.join(ROOT, 'logfile.log') DEBUGLEVEL = logging.DEBUG logging.basicConfig(filename=LOGFILENAME, filemode='w', format='%(levelname)s: %(name)s - %(message)s') diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py index 31ee767..07e0d37 100644 --- a/hyppopy/settingspluginbase.py +++ b/hyppopy/settingspluginbase.py @@ -1,87 +1,89 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) -from hyppopy.globals import SETTINGSPATH, CUSTOMPATH +from hyppopy.globals import SETTINGSSOLVERPATH, SETTINGSCUSTOMPATH from hyppopy.deepdict import DeepDict class SettingsPluginBase(object): _data = None _name = None def __init__(self): self._data = DeepDict() @abc.abstractmethod def convert_parameter(self): raise NotImplementedError('users must define convert_parameter to use this base class') def get_hyperparameter(self): return self.convert_parameter(self.data["hyperparameter"]) def set(self, data): self.data.clear() self.data = data def read(self, fname): self.data.clear() self.data.from_file(fname) def write(self, fname): self.data.to_file(fname) def set_attributes(self, cls): - attrs_sec = self.data[SETTINGSPATH] - for key, value in attrs_sec.items(): - setattr(cls, key, value) - attrs_sec = self.data[CUSTOMPATH] - for key, value in attrs_sec.items(): - setattr(cls, key, value) + if self.data.has_section(SETTINGSSOLVERPATH.split('/')[-1]): + attrs_sec = self.data[SETTINGSSOLVERPATH] + for key, value in attrs_sec.items(): + setattr(cls, key, value) + if self.data.has_section(SETTINGSCUSTOMPATH.split('/')[-1]): + attrs_sec = self.data[SETTINGSCUSTOMPATH] + for key, value in attrs_sec.items(): + setattr(cls, key, value) @property def data(self): return self._data @data.setter def data(self, value): if isinstance(value, dict): self._data.data = value elif isinstance(value, DeepDict): self._data = value else: raise IOError(f"unexpected input type({type(value)}) for data, needs to be of type dict or DeepDict!") @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/singleton.py b/hyppopy/singleton.py new file mode 100644 index 0000000..6f61770 --- /dev/null +++ b/hyppopy/singleton.py @@ -0,0 +1,52 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + @classmethod + def __instancecheck__(mcs, instance): + if instance.__class__ is mcs: + return True + else: + return isinstance(instance.__class__, mcs) + + +def singleton_object(cls): + """Class decorator that transforms (and replaces) a class definition (which + must have a Singleton metaclass) with the actual singleton object. Ensures + that the resulting object can still be "instantiated" (i.e., called), + returning the same object. Also ensures the object can be pickled, is + hashable, and has the correct string representation (the name of the + singleton) + """ + assert isinstance(cls, Singleton), cls.__name__ + " must use Singleton metaclass" + + def self_instantiate(self): + return self + + cls.__call__ = self_instantiate + cls.__hash__ = lambda self: hash(cls) + cls.__repr__ = lambda self: cls.__name__ + cls.__reduce__ = lambda self: cls.__name__ + obj = cls() + obj.__name__ = cls.__name__ + return obj diff --git a/hyppopy/solver.py b/hyppopy/solver.py index 14262fd..35de7c8 100644 --- a/hyppopy/solver.py +++ b/hyppopy/solver.py @@ -1,84 +1,86 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Solver(object): _name = None _solver_plugin = None _settings_plugin = None def __init__(self): pass def set_data(self, data): self.solver.set_data(data) def set_parameters(self, params): self.settings.set(params) self.settings.set_attributes(self.solver) + self.settings.set_attributes(self.settings) def read_parameter(self, fname): self.settings.read(fname) + self.settings.set_attributes(self.solver) self.settings.set_attributes(self.settings) def set_loss_function(self, loss_func): self.solver.set_loss_function(loss_func) def run(self): self.solver.settings = self.settings self.solver.run() def get_results(self): return self.solver.get_results() @property def is_ready(self): return self.solver is not None and self.settings is not None @property def solver(self): return self._solver_plugin @solver.setter def solver(self, value): self._solver_plugin = value @property def settings(self): return self._settings_plugin @settings.setter def settings(self, value): self._settings_plugin = value @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solverfactory.py b/hyppopy/solverfactory.py index 9b39f38..cfbe1ed 100644 --- a/hyppopy/solverfactory.py +++ b/hyppopy/solverfactory.py @@ -1,180 +1,156 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from yapsy.PluginManager import PluginManager from hyppopy.globals import PLUGIN_DEFAULT_DIR from hyppopy.deepdict import DeepDict from hyppopy.solver import Solver +from hyppopy.singleton import * import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) - -class SolverFactory(object): +@singleton_object +class SolverFactory(metaclass=Singleton): """ This class is responsible for grabbing all plugins from the plugin folder arranging them into a Solver class instances. These Solver class instances can be requested from the factory via the get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using - SolverFactory(), the consequences will be horrific. Instead use factory = SolverFactory.instance(). + SolverFactory(), the consequences will be horrific. Instead use is like a class having static + functions only, SolverFactory.method(). """ - _instance = None - _locked = True _plugin_dirs = [] _plugins = {} def __init__(self): - if self._locked: - msg = "!!! seems you used SolverFactory() to get an instance, please don't do that, "\ - "it will kill a cute puppy anywhere close to you! SolverFactory is a "\ - "Singleton, means please use SolverFactory.instance() instead !!!" - LOG.error(msg) - raise AssertionError(msg) - if SolverFactory._instance is not None: - pass - else: - SolverFactory._instance = self - self.reset() - self.load_plugins() - LOG.debug("initialized") - - @staticmethod - def instance(): - """ - Singleton instance access - :return: [SolverFactory] instance - """ - SolverFactory._locked = False - LOG.debug("instance request") - if SolverFactory._instance is None: - SolverFactory() - SolverFactory._locked = True - return SolverFactory._instance + print("Solverfactory: I'am alive!") + self.reset() + self.load_plugins() + LOG.debug("Solverfactory initialized") def load_plugins(self): """ Load plugin modules from plugin paths """ LOG.debug("load_plugins()") manager = PluginManager() LOG.debug(f"setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) manager.setPluginPlaces(self._plugin_dirs) manager.collectPlugins() for plugin in manager.getAllPlugins(): name_elements = plugin.plugin_object.__class__.__name__.split("_") LOG.debug("found plugin " + " ".join(map(str, name_elements))) if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements): LOG.error(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") raise NameError(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") if name_elements[0] not in self._plugins.keys(): self._plugins[name_elements[0]] = Solver() self._plugins[name_elements[0]].name = name_elements[0] if name_elements[1] == "Solver": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].solver = obj LOG.info(f"plugin: {name_elements[0]} Solver loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"Failed to instanciate class {plugin.plugin_object.__class__.__name__}") elif name_elements[1] == "Settings": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].settings = obj LOG.info(f"plugin: {name_elements[0]} ParameterSpace loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") else: LOG.error(f"failed loading plugin {name_elements[0]}, please check if naming conventions are kept!") raise IOError(f"failed loading plugin {name_elements[0]}!, please check if naming conventions are kept!") if len(self._plugins) == 0: msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!" LOG.error(msg) raise IOError(msg) def reset(self): """ Reset solver factory """ LOG.debug("reset()") self._plugins = {} self._plugin_dirs = [] self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR)) def add_plugin_dir(self, dir): """ Add plugin directory """ LOG.debug(f"add_plugin_dir({dir})") self._plugin_dirs.append(dir) def list_solver(self): """ list all solvers available :return: [list(str)] """ return list(self._plugins.keys()) def from_settings(self, settings): if isinstance(settings, dict): tmp = DeepDict() tmp.data = settings settings = tmp elif isinstance(settings, str): if not os.path.isfile(settings): LOG.warning(f"input error, file {settings} not found!") settings = DeepDict(settings) if isinstance(settings, DeepDict): if settings.has_section("use_plugin"): try: use_plugin = settings["settings/solver/use_plugin"] except Exception as e: LOG.warning("wrong settings path for use_plugin option detected, expecting the path settings/solver/use_plugin!") solver = self.get_solver(use_plugin) solver.set_parameters(settings) return solver LOG.warning("failed to choose a solver, either the config file is missing the section settings/solver/use_plugin, or there might be a typo") else: msg = "unknown input error, expected DeepDict, dict or filename!" LOG.error(msg) raise IOError(msg) return None def get_solver(self, name): """ returns a solver by name tag :param name: [str] solver name :return: [Solver] instance """ if not isinstance(name, str): msg = f"Invalid input, str type expected for name, got {type(name)} instead" LOG.error(msg) raise IOError(msg) if name not in self.list_solver(): msg = f"failed solver request, a solver called {name} is not available, " \ f"check for typo or if your plugin failed while loading!" LOG.error(msg) raise LookupError(msg) LOG.debug(f"get_solver({name})") return self._plugins[name] diff --git a/hyppopy/tests/data/Titanic/svc_config.xml b/hyppopy/tests/data/Titanic/svc_config.xml index 5c2b275..b26c191 100644 --- a/hyppopy/tests/data/Titanic/svc_config.xml +++ b/hyppopy/tests/data/Titanic/svc_config.xml @@ -1,34 +1,34 @@ uniform [0,20] float uniform [0.0001,20.0] float categorical [linear,sigmoid,poly,rbf] str categorical [ovo,ovr] str 3 - hyperopt + optunity train_cleaned.csv Survived \ No newline at end of file diff --git a/hyppopy/tests/test_solver_factory.py b/hyppopy/tests/test_solver_factory.py index 0d87856..177a27a 100644 --- a/hyppopy/tests/test_solver_factory.py +++ b/hyppopy/tests/test_solver_factory.py @@ -1,109 +1,104 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import unittest from sklearn.svm import SVC from sklearn import datasets from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split from hyppopy.solverfactory import SolverFactory from hyppopy.globals import TESTDATA_DIR TESTPARAMFILE = os.path.join(TESTDATA_DIR, 'iris_svc_parameter') from hyppopy.deepdict import DeepDict class SolverFactoryTestSuite(unittest.TestCase): def setUp(self): iris = datasets.load_iris() X, X_test, y, y_test = train_test_split(iris.data, iris.target, test_size=0.1, random_state=42) self.my_IRIS_dta = [X, y] def test_solver_loading(self): - factory = SolverFactory.instance() - names = factory.list_solver() + names = SolverFactory.list_solver() self.assertTrue("hyperopt" in names) self.assertTrue("optunity" in names) def test_iris_solver_execution(self): - def my_SVC_loss_func(data, params): clf = SVC(**params) return -cross_val_score(clf, data[0], data[1], cv=3).mean() - factory = SolverFactory.instance() - - solver = factory.get_solver('optunity') + solver = SolverFactory.get_solver('optunity') solver.set_data(self.my_IRIS_dta) solver.read_parameter(TESTPARAMFILE + '.xml') solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() - solver = factory.get_solver('hyperopt') + solver = SolverFactory.get_solver('hyperopt') solver.set_data(self.my_IRIS_dta) solver.read_parameter(TESTPARAMFILE + '.json') solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() def test_create_solver_from_settings_directly(self): - factory = SolverFactory.instance() def my_SVC_loss_func(data, params): clf = SVC(**params) return -cross_val_score(clf, data[0], data[1], cv=3).mean() - solver = factory.from_settings(TESTPARAMFILE + '.xml') + solver = SolverFactory.from_settings(TESTPARAMFILE + '.xml') self.assertEqual(solver.name, "optunity") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() - solver = factory.from_settings(TESTPARAMFILE + '.json') + solver = SolverFactory.from_settings(TESTPARAMFILE + '.json') self.assertEqual(solver.name, "hyperopt") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() dd = DeepDict(TESTPARAMFILE + '.json') - solver = factory.from_settings(dd) + solver = SolverFactory.from_settings(dd) self.assertEqual(solver.name, "hyperopt") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() - solver = factory.from_settings(dd.data) + solver = SolverFactory.from_settings(dd.data) self.assertEqual(solver.name, "hyperopt") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() if __name__ == '__main__': unittest.main() diff --git a/hyppopy/tests/test_workflows.py b/hyppopy/tests/test_workflows.py index ac64601..8071a1a 100644 --- a/hyppopy/tests/test_workflows.py +++ b/hyppopy/tests/test_workflows.py @@ -1,130 +1,130 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import unittest from hyppopy.globals import TESTDATA_DIR IRIS_DATA = os.path.join(TESTDATA_DIR, 'Iris') TITANIC_DATA = os.path.join(TESTDATA_DIR, 'Titanic') from hyppopy.workflows.svc_usecase.svc_usecase import svc_usecase from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase class Args(object): def __init__(self): pass def set_arg(self, name, value): setattr(self, name, value) class WorkflowTestSuite(unittest.TestCase): def setUp(self): self.results = [] def test_workflow_svc_on_iris_from_xml(self): svc_args_xml = Args() svc_args_xml.set_arg('plugin', '') svc_args_xml.set_arg('data', IRIS_DATA) svc_args_xml.set_arg('config', os.path.join(IRIS_DATA, 'svc_config.xml')) uc = svc_usecase(svc_args_xml) uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_iris_from_xml(self): rf_args_xml = Args() rf_args_xml.set_arg('plugin', '') rf_args_xml.set_arg('data', IRIS_DATA) rf_args_xml.set_arg('config', os.path.join(IRIS_DATA, 'rf_config.xml')) uc = svc_usecase(rf_args_xml) uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_svc_on_iris_from_json(self): svc_args_json = Args() svc_args_json.set_arg('plugin', '') svc_args_json.set_arg('data', IRIS_DATA) svc_args_json.set_arg('config', os.path.join(IRIS_DATA, 'svc_config.json')) uc = svc_usecase(svc_args_json) uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_iris_from_json(self): rf_args_json = Args() rf_args_json.set_arg('plugin', '') rf_args_json.set_arg('data', IRIS_DATA) rf_args_json.set_arg('config', os.path.join(IRIS_DATA, 'rf_config.json')) - uc = svc_usecase(rf_args_json) + uc = randomforest_usecase(rf_args_json) uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) - # def test_workflow_svc_on_titanic_from_xml(self): - # svc_args_xml = Args() - # svc_args_xml.set_arg('plugin', '') - # svc_args_xml.set_arg('data', TITANIC_DATA) - # svc_args_xml.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.xml')) - # uc = svc_usecase(svc_args_xml) - # uc.run() - # self.results.append(uc.get_results()) - # self.assertTrue(uc.get_results().find("Solution") != -1) + def test_workflow_svc_on_titanic_from_xml(self): + svc_args_xml = Args() + svc_args_xml.set_arg('plugin', '') + svc_args_xml.set_arg('data', TITANIC_DATA) + svc_args_xml.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.xml')) + uc = svc_usecase(svc_args_xml) + uc.run() + self.results.append(uc.get_results()) + self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_titanic_from_xml(self): rf_args_xml = Args() rf_args_xml.set_arg('plugin', '') rf_args_xml.set_arg('data', TITANIC_DATA) rf_args_xml.set_arg('config', os.path.join(TITANIC_DATA, 'rf_config.xml')) - uc = svc_usecase(rf_args_xml) + uc = randomforest_usecase(rf_args_xml) uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) - # def test_workflow_svc_on_titanic_from_json(self): - # svc_args_json = Args() - # svc_args_json.set_arg('plugin', '') - # svc_args_json.set_arg('data', TITANIC_DATA) - # svc_args_json.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.json')) - # uc = svc_usecase(svc_args_json) - # uc.run() - # self.results.append(uc.get_results()) - # self.assertTrue(uc.get_results().find("Solution") != -1) + def test_workflow_svc_on_titanic_from_json(self): + svc_args_json = Args() + svc_args_json.set_arg('plugin', '') + svc_args_json.set_arg('data', TITANIC_DATA) + svc_args_json.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.json')) + uc = svc_usecase(svc_args_json) + uc.run() + self.results.append(uc.get_results()) + self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_titanic_from_json(self): rf_args_json = Args() rf_args_json.set_arg('plugin', '') rf_args_json.set_arg('data', TITANIC_DATA) rf_args_json.set_arg('config', os.path.join(TITANIC_DATA, 'rf_config.json')) - uc = svc_usecase(rf_args_json) + uc = randomforest_usecase(rf_args_json) uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def tearDown(self): print("") for r in self.results: print(r) if __name__ == '__main__': unittest.main() diff --git a/hyppopy/workflowbase.py b/hyppopy/workflowbase.py index 548c0e9..fea491b 100644 --- a/hyppopy/workflowbase.py +++ b/hyppopy/workflowbase.py @@ -1,77 +1,76 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) -import hyppopy.solverfactory as sfac +from hyppopy.solverfactory import SolverFactory from hyppopy.deepdict import DeepDict -from hyppopy.globals import SETTINGSPATH +from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH import os import abc import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Workflow(object): _solver = None _args = None def __init__(self, args): self._args = args - factory = sfac.SolverFactory.instance() if args.plugin is None or args.plugin == '': dd = DeepDict(args.config) ppath = "use_plugin" if not dd.has_section(ppath): msg = f"invalid config file, missing section {ppath}" LOG.error(msg) raise LookupError(msg) - plugin = dd[SETTINGSPATH+'/'+ppath] + plugin = dd[SETTINGSSOLVERPATH+'/'+ppath] else: plugin = args.plugin - self._solver = factory.get_solver(plugin) + self._solver = SolverFactory.get_solver(plugin) self.solver.read_parameter(args.config) def run(self): self.setup() self.solver.set_loss_function(self.blackbox_function) self.solver.run() self.test() def get_results(self): return self.solver.get_results() @abc.abstractmethod def setup(self): raise NotImplementedError('the user has to implement this function') @abc.abstractmethod def blackbox_function(self): raise NotImplementedError('the user has to implement this function') @abc.abstractmethod def test(self): pass @property def solver(self): return self._solver @property def args(self): return self._args