diff --git a/bin/hyppopy_exe.py b/bin/hyppopy_exe.py index 10f41e7..8bf5103 100644 --- a/bin/hyppopy_exe.py +++ b/bin/hyppopy_exe.py @@ -1,86 +1,69 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) +from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.unet_usecase.unet_usecase import unet_usecase from hyppopy.workflows.svc_usecase.svc_usecase import svc_usecase from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase import os import sys import argparse -import hyppopy.solverfactory as sfac - - -solver_factory = sfac.SolverFactory.instance() def print_warning(msg): print("\n!!!!! WARNING !!!!!") print(msg) sys.exit() def args_check(args): if not args.workflow: print_warning("No workflow specified, check --help") if not args.config: print_warning("Missing config parameter, check --help") - if not args.data: - print_warning("Missing data parameter, check --help") - if not os.path.isdir(args.data): - print_warning("Couldn't find data path, please check your input --data") - if not os.path.isfile(args.config): - tmp = os.path.join(args.data, args.config) - if not os.path.isfile(tmp): - print_warning("Couldn't find the config file, please check your input --config") - args.config = tmp + print_warning(f"Couldn't find configfile ({args.config}), please check your input --config") if __name__ == "__main__": parser = argparse.ArgumentParser(description='UNet Hyppopy UseCase Example Optimization.') parser.add_argument('-w', '--workflow', type=str, help='workflow to be executed') - parser.add_argument('-p', '--plugin', type=str, default='', - help='plugin to be used default=[hyperopt], optunity') - parser.add_argument('-d', '--data', type=str, help='training data path') parser.add_argument('-c', '--config', type=str, help='config filename, .xml or .json formats are supported.' 'pass a full path filename or the filename only if the' 'configfile is in the data folder') - parser.add_argument('-i', '--iterations', type=int, default=0, - help='number of iterations, default=[0] if set to 0 the value set via configfile is used, ' - 'otherwise the configfile value will be overwritten') args = parser.parse_args() - args_check(args) + ProjectManager.read_config(args.config) + if args.workflow == "svc_usecase": - uc = svc_usecase(args) + uc = svc_usecase() elif args.workflow == "randomforest_usecase": - uc = randomforest_usecase(args) + uc = randomforest_usecase() elif args.workflow == "unet_usecase": - uc = unet_usecase(args) + uc = unet_usecase() else: print(f"No workflow called {args.workflow} found!") sys.exit() uc.run() print(uc.get_results()) diff --git a/hyppopy/deepdict.py b/hyppopy/deepdict.py index a149797..f8e84d5 100644 --- a/hyppopy/deepdict.py +++ b/hyppopy/deepdict.py @@ -1,435 +1,437 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import re import json import types import pprint import xmltodict from dicttoxml import dicttoxml from collections import OrderedDict import logging LOG = logging.getLogger('hyppopy') from hyppopy.globals import DEEPDICT_XML_ROOT def convert_ordered2std_dict(obj): """ Helper function converting an OrderedDict into a standard lib dict. :param obj: [OrderedDict] """ for key, value in obj.items(): if isinstance(value, OrderedDict): obj[key] = dict(obj[key]) convert_ordered2std_dict(obj[key]) def check_dir_existance(dirname): """ Helper function to check if a directory exists, creating it if not. :param dirname: [str] full path of the directory to check """ if not os.path.exists(dirname): os.mkdir(dirname) class DeepDict(object): """ The DeepDict class represents a nested dictionary with additional functionality compared to a standard lib dict. The data can be accessed and changed vie a pathlike access and dumped or read to .json/.xml files. Initializing instances using defaults creates an empty DeepDict. Using in_data enables to initialize the object instance with data, where in_data can be a dict, or a filepath to a json or xml file. Using path sep the appearance of path passing can be changed, a default data access via path would look like my_dd['target/section/path'] with path_sep='.' like so my_dd['target.section.path'] :param in_data: [dict] or [str], input dict or filename :param path_sep: [str] path separator character """ _data = None _sep = "/" def __init__(self, in_data=None, path_sep="/"): self.clear() self._sep = path_sep LOG.debug(f"path separator is: {self._sep}") if in_data is not None: if isinstance(in_data, str): self.from_file(in_data) elif isinstance(in_data, dict): self.data = in_data def __str__(self): """ Enables print output for class instances, printing the instance data dict using pretty print :return: [str] """ return pprint.pformat(self.data) def __eq__(self, other): """ Overloads the == operator comparing the instance data dictionaries for equality :param other: [DeepDict] rhs :return: [bool] """ return self.data == other.data def __getitem__(self, path): """ Overloads the return of the [] operator for data access. This enables access the DeepDict instance like so: my_dd['target/section/path'] or my_dd[['target','section','path']] :param path: [str] or [list(str)], the path to the target data structure level/content :return: [object] """ return DeepDict.get_from_path(self.data, path, self.sep) def __setitem__(self, path, value=None): """ Overloads the setter for the [] operator for data assignment. :param path: [str] or [list(str)], the path to the target data structure level/content :param value: [object] rhs assignment object """ if isinstance(path, str): path = path.split(self.sep) if not isinstance(path, list) or isinstance(path, tuple): raise IOError("Input Error, expect list[str] type for path") if len(path) < 1: raise IOError("Input Error, missing section strings") if not path[0] in self._data.keys(): if value is not None and len(path) == 1: self._data[path[0]] = value else: self._data[path[0]] = {} tmp = self._data[path[0]] path.pop(0) while True: if len(path) == 0: break if path[0] not in tmp.keys(): if value is not None and len(path) == 1: tmp[path[0]] = value else: tmp[path[0]] = {} tmp = tmp[path[0]] else: tmp = tmp[path[0]] path.pop(0) def __len__(self): return len(self._data) def items(self): return self.data.items() def clear(self): """ clears the instance data """ LOG.debug("clear()") self._data = {} def from_file(self, fname): """ Loads data from file. Currently implemented .json and .xml file reader. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.read_json(fname) elif fname.endswith(".xml"): self.read_xml(fname) else: LOG.error("Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def read_json(self, fname): """ Read json file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): raise IOError(f"File {fname} not found!") LOG.debug(f"read_json({fname})") try: with open(fname, "r") as read_file: self._data = json.load(read_file) DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: LOG.error(f"Error while reading json file {fname} or while converting types") raise IOError("Error while reading json file {fname} or while converting types") def read_xml(self, fname): """ Read xml file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): raise IOError(f"File {fname} not found!") LOG.debug(f"read_xml({fname})") try: with open(fname, "r") as read_file: xml = "".join(read_file.readlines()) self._data = xmltodict.parse(xml, attr_prefix='') DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: msg = f"Error while reading xml file {fname} or while converting types" LOG.error(msg) raise IOError(msg) # if written with DeepDict, the xml contains a root node called # deepdict which should beremoved for consistency reasons if DEEPDICT_XML_ROOT in self._data.keys(): self._data = self._data[DEEPDICT_XML_ROOT] self._data = dict(self.data) # convert the orderes dict structure to a default dict for consistency reasons convert_ordered2std_dict(self.data) def to_file(self, fname): """ Write to file, type is determined by checking the filename ending. Currently implemented is writing to json and to xml. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.write_json(fname) elif fname.endswith(".xml"): self.write_xml(fname) else: LOG.error(f"Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def write_json(self, fname): """ Dump data to json file. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) try: LOG.debug(f"write_json({fname})") with open(fname, "w") as write_file: json.dump(self.data, write_file) except Exception as e: LOG.error(f"Failed dumping to json file: {fname}") raise e def write_xml(self, fname): """ Dump data to json file. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) xml = dicttoxml(self.data, custom_root=DEEPDICT_XML_ROOT, attr_type=False) LOG.debug(f"write_xml({fname})") try: with open(fname, "w") as write_file: write_file.write(xml.decode("utf-8")) except Exception as e: LOG.error(f"Failed dumping to xml file: {fname}") raise e def has_section(self, section): return DeepDict.has_key(self.data, section) @staticmethod def get_from_path(data, path, sep="/"): """ Implements a nested dict access via a path like string like so path='target/section/path' which is equivalent to my_dict['target']['section']['path']. :param data: [dict] input dictionary :param path: [str] pathlike string :param sep: [str] path separator, default='/' :return: [object] """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for data") raise IOError("Input Error, expect dict type for data") if isinstance(path, str): path = path.split(sep) if not isinstance(path, list) or isinstance(path, tuple): LOG.error(f"Input Error, expect list[str] type for path: {path}") raise IOError("Input Error, expect list[str] type for path") if not DeepDict.has_key(data, path[-1]): LOG.error(f"Input Error, section {path[-1]} does not exist in dictionary") raise IOError(f"Input Error, section {path[-1]} does not exist in dictionary") try: for k in path: data = data[k] except Exception as e: LOG.error(f"Failed retrieving data from path {path} due to {e}") raise LookupError(f"Failed retrieving data from path {path} due to {e}") return data @staticmethod def has_key(data, section, already_found=False): """ Checks if input dictionary has a key called section. The already_found parameter is for internal recursion checks. :param data: [dict] input dictionary :param section: [str] key string to search for :param already_found: recursion criteria check :return: [bool] section found """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(section, str): LOG.error(f"Input Error, expect dict type for obj {section}") raise IOError(f"Input Error, expect dict type for obj {section}") if already_found: return True found = False for key, value in data.items(): if key == section: found = True if isinstance(value, dict): found = DeepDict.has_key(data[key], section, found) return found @staticmethod def value_traverse(data, callback=None): """ Dictionary filter function, walks through the input dict (obj) calling the callback function for each value. The callback function return is assigned the the corresponding dict value. :param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.value_traverse(data[key], callback) else: data[key] = callback(value) def transfer_attrs(self, cls, target_section): - def set(item): + items_set = [] + + def set_item(item): + items_set.append(item[0]) setattr(cls, item[0], item[1]) - DeepDict.sectionconstraint_item_traverse(self.data, target_section, callback=set, section=None) + DeepDict.sectionconstraint_item_traverse(self.data, target_section, callback=set_item, section=None) + return items_set @staticmethod def sectionconstraint_item_traverse(data, target_section, callback=None, section=None): """ Dictionary filter function, walks through the input dict (obj) calling the callback function for each item. The callback function then is called with the key value pair as tuple input but only for the target section. :param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.sectionconstraint_item_traverse(data[key], target_section, callback, key) else: if target_section == section: callback((key, value)) @staticmethod def item_traverse(data, callback=None): """ Dictionary filter function, walks through the input dict (obj) calling the callback function for each item. The callback function then is called with the key value pair as tuple input. :param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.value_traverse(data[key], callback) else: callback((key, value)) @staticmethod def parse_type(string): """ Type convert input string to float, int, list, tuple or string :param string: [str] input string :return: [T] converted output """ try: a = float(string) try: b = int(string) except ValueError: return float(string) if a == b: return b return a except ValueError: if string.startswith("[") and string.endswith("]"): string = re.sub(' ', '', string) elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return li elif string.startswith("(") and string.endswith(")"): elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return tuple(li) return string @property def data(self): return self._data @data.setter def data(self, value): if not isinstance(value, dict): LOG.error(f"Input Error, expect dict type for value, but got {type(value)}") raise IOError(f"Input Error, expect dict type for value, but got {type(value)}") self.clear() self._data = value @property def sep(self): return self._sep @sep.setter def sep(self, value): if not isinstance(value, str): LOG.error(f"Input Error, expect str type for value, but got {type(value)}") raise IOError(f"Input Error, expect str type for value, but got {type(value)}") self._sep = value diff --git a/hyppopy/globals.py b/hyppopy/globals.py index b06a233..922823c 100644 --- a/hyppopy/globals.py +++ b/hyppopy/globals.py @@ -1,32 +1,30 @@ # DKFZ # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. -# -*- coding: utf-8 -*- - import os import sys import logging ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT) PLUGIN_DEFAULT_DIR = os.path.join(ROOT, *("hyppopy", "plugins")) TESTDATA_DIR = os.path.join(ROOT, *("hyppopy", "tests", "data")) SETTINGSSOLVERPATH = "settings/solver" SETTINGSCUSTOMPATH = "settings/custom" DEEPDICT_XML_ROOT = "hyppopy" LOGFILENAME = os.path.join(ROOT, 'logfile.log') DEBUGLEVEL = logging.DEBUG logging.basicConfig(filename=LOGFILENAME, filemode='w', format='%(levelname)s: %(name)s - %(message)s') diff --git a/hyppopy/projectmanager.py b/hyppopy/projectmanager.py index 24eed29..a85ece0 100644 --- a/hyppopy/projectmanager.py +++ b/hyppopy/projectmanager.py @@ -1,55 +1,111 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.singleton import * from hyppopy.deepdict import DeepDict from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) @singleton_object class ProjectManager(metaclass=Singleton): def __init__(self): self.configfilename = None self.config = None + self._extmembers = [] + + def clear(self): + self.configfilename = None + self.config = None + self.remove_externals() + + def is_ready(self): + return self.config is not None + + def remove_externals(self): + for added in self._extmembers: + if added in self.__dict__.keys(): + del self.__dict__[added] + self._extmembers = [] + + def get_hyperparameter(self): + return self.config["hyperparameter"] def test_config(self): - #TODO test the config structure to fullfill the needs, throwing useful error is not + if not isinstance(self.config, DeepDict): + msg = f"test_config failed, config is not of type DeepDict" + LOG.error(msg) + return False + sections = ["hyperparameter"] + sections += SETTINGSSOLVERPATH.split("/") + sections += SETTINGSCUSTOMPATH.split("/") + for sec in sections: + if not self.config.has_section(sec): + msg = f"test_config failed, config has no section {sec}" + LOG.error(msg) + return False + return True + + def set_config(self, config): + self.clear() + if isinstance(config, dict): + self.config = DeepDict() + self.config.data = config + elif isinstance(config, DeepDict): + self.config = config + else: + msg = f"unknown type ({type(config)}) for config passed, expected dict or DeepDict" + LOG.error(msg) + raise IOError(msg) + + if not self.test_config(): + self.clear() + return False + + try: + self._extmembers += self.config.transfer_attrs(self, SETTINGSCUSTOMPATH.split("/")[-1]) + self._extmembers += self.config.transfer_attrs(self, SETTINGSSOLVERPATH.split("/")[-1]) + except Exception as e: + msg = f"transfering custom section as class attributes failed, " \ + f"is the config path to your custom section correct? {SETTINGSCUSTOMPATH}. Exception {e}" + LOG.error(msg) + raise LookupError(msg) + return True def read_config(self, configfile): + self.clear() self.configfilename = configfile self.config = DeepDict(configfile) if not self.test_config(): - self.configfilename = None - self.config = None + self.clear() return False try: - self.config.transfer_attrs(self, SETTINGSCUSTOMPATH.split("/")[-1]) - self.config.transfer_attrs(self, SETTINGSSOLVERPATH.split("/")[-1]) + self._extmembers += self.config.transfer_attrs(self, SETTINGSCUSTOMPATH.split("/")[-1]) + self._extmembers += self.config.transfer_attrs(self, SETTINGSSOLVERPATH.split("/")[-1]) except Exception as e: msg = f"transfering custom section as class attributes failed, " \ f"is the config path to your custom section correct? {SETTINGSCUSTOMPATH}. Exception {e}" LOG.error(msg) raise LookupError(msg) return True diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py index 3d8d49a..5403acb 100644 --- a/hyppopy/settingspluginbase.py +++ b/hyppopy/settingspluginbase.py @@ -1,77 +1,78 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os +import copy import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from hyppopy.globals import SETTINGSSOLVERPATH, SETTINGSCUSTOMPATH from hyppopy.deepdict import DeepDict class SettingsPluginBase(object): _data = None _name = None def __init__(self): self._data = {} @abc.abstractmethod def convert_parameter(self): raise NotImplementedError('users must define convert_parameter to use this base class') def get_hyperparameter(self): return self.convert_parameter(self.data) - def set(self, data): + def set_hyperparameter(self, input_data): self.data.clear() - self.data = data + self.data = copy.deepcopy(input_data) def read(self, fname): self.data.clear() self.data.from_file(fname) def write(self, fname): self.data.to_file(fname) @property def data(self): return self._data @data.setter def data(self, value): if isinstance(value, dict): self._data = value elif isinstance(value, DeepDict): self._data = value.data else: raise IOError(f"unexpected input type({type(value)}) for data, needs to be of type dict or DeepDict!") @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solver.py b/hyppopy/solver.py index 1ae9740..de02af9 100644 --- a/hyppopy/solver.py +++ b/hyppopy/solver.py @@ -1,77 +1,85 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) + +from hyppopy.projectmanager import ProjectManager + import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Solver(object): _name = None _solver_plugin = None _settings_plugin = None def __init__(self): pass def set_data(self, data): self.solver.set_data(data) def set_hyperparameters(self, params): - self.settings.set(params) + self.settings.set_hyperparameter(params) def set_loss_function(self, loss_func): self.solver.set_loss_function(loss_func) def run(self): + if not ProjectManager.is_ready(): + LOG.error("No config data found to initialize PluginSetting object") + raise IOError("No config data found to initialize PluginSetting object") + hyps = ProjectManager.get_hyperparameter() + self.settings.set_hyperparameter(hyps) self.solver.settings = self.settings self.solver.run() def get_results(self): return self.solver.get_results() @property def is_ready(self): return self.solver is not None and self.settings is not None @property def solver(self): return self._solver_plugin @solver.setter def solver(self, value): self._solver_plugin = value @property def settings(self): return self._settings_plugin @settings.setter def settings(self, value): self._settings_plugin = value @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solverfactory.py b/hyppopy/solverfactory.py index 697c33c..fad3f74 100644 --- a/hyppopy/solverfactory.py +++ b/hyppopy/solverfactory.py @@ -1,156 +1,155 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from yapsy.PluginManager import PluginManager +from hyppopy.projectmanager import ProjectManager from hyppopy.globals import PLUGIN_DEFAULT_DIR from hyppopy.deepdict import DeepDict from hyppopy.solver import Solver from hyppopy.singleton import * import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) @singleton_object class SolverFactory(metaclass=Singleton): """ This class is responsible for grabbing all plugins from the plugin folder arranging them into a Solver class instances. These Solver class instances can be requested from the factory via the get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using SolverFactory(), the consequences will be horrific. Instead use is like a class having static functions only, SolverFactory.method(). """ _plugin_dirs = [] _plugins = {} def __init__(self): print("Solverfactory: I'am alive!") self.reset() self.load_plugins() LOG.debug("Solverfactory initialized") def load_plugins(self): """ Load plugin modules from plugin paths """ LOG.debug("load_plugins()") manager = PluginManager() LOG.debug(f"setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) manager.setPluginPlaces(self._plugin_dirs) manager.collectPlugins() for plugin in manager.getAllPlugins(): name_elements = plugin.plugin_object.__class__.__name__.split("_") LOG.debug("found plugin " + " ".join(map(str, name_elements))) if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements): LOG.error(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") raise NameError(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") if name_elements[0] not in self._plugins.keys(): self._plugins[name_elements[0]] = Solver() self._plugins[name_elements[0]].name = name_elements[0] if name_elements[1] == "Solver": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].solver = obj LOG.info(f"plugin: {name_elements[0]} Solver loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"Failed to instanciate class {plugin.plugin_object.__class__.__name__}") elif name_elements[1] == "Settings": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].settings = obj LOG.info(f"plugin: {name_elements[0]} ParameterSpace loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") else: LOG.error(f"failed loading plugin {name_elements[0]}, please check if naming conventions are kept!") raise IOError(f"failed loading plugin {name_elements[0]}!, please check if naming conventions are kept!") if len(self._plugins) == 0: msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!" LOG.error(msg) raise IOError(msg) def reset(self): """ Reset solver factory """ LOG.debug("reset()") self._plugins = {} self._plugin_dirs = [] self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR)) def add_plugin_dir(self, dir): """ Add plugin directory """ LOG.debug(f"add_plugin_dir({dir})") self._plugin_dirs.append(dir) def list_solver(self): """ list all solvers available :return: [list(str)] """ return list(self._plugins.keys()) def from_settings(self, settings): - if isinstance(settings, dict): - tmp = DeepDict() - tmp.data = settings - settings = tmp - elif isinstance(settings, str): + if isinstance(settings, str): if not os.path.isfile(settings): - LOG.warning(f"input error, file {settings} not found!") - settings = DeepDict(settings) - - if isinstance(settings, DeepDict): - if settings.has_section("use_plugin"): - try: - use_plugin = settings["settings/solver/use_plugin"] - except Exception as e: - LOG.warning("wrong settings path for use_plugin option detected, expecting the path settings/solver/use_plugin!") - solver = self.get_solver(use_plugin) - solver.set_hyperparameters(settings['hyperparameter']) - return solver - LOG.warning("failed to choose a solver, either the config file is missing the section settings/solver/use_plugin, or there might be a typo") + LOG.error(f"input error, file {settings} not found!") + if not ProjectManager.read_config(settings): + LOG.error("failed to read config in ProjectManager!") + return None else: - msg = "unknown input error, expected DeepDict, dict or filename!" + if not ProjectManager.set_config(settings): + LOG.error("failed to set config in ProjectManager!") + return None + + if not ProjectManager.is_ready(): + LOG.error("failed to set config in ProjectManager!") + return None + + try: + solver = self.get_solver(ProjectManager.use_plugin) + except Exception as e: + msg = f"failed to create solver, reason {e}" LOG.error(msg) - raise IOError(msg) - return None + return None + return solver def get_solver(self, name): """ returns a solver by name tag :param name: [str] solver name :return: [Solver] instance """ if not isinstance(name, str): msg = f"Invalid input, str type expected for name, got {type(name)} instead" LOG.error(msg) raise IOError(msg) if name not in self.list_solver(): msg = f"failed solver request, a solver called {name} is not available, " \ f"check for typo or if your plugin failed while loading!" LOG.error(msg) raise LookupError(msg) LOG.debug(f"get_solver({name})") return self._plugins[name] diff --git a/hyppopy/tests/data/iris_svc_parameter.json b/hyppopy/tests/data/iris_svc_parameter.json index 45e2ff2..eb60183 100644 --- a/hyppopy/tests/data/iris_svc_parameter.json +++ b/hyppopy/tests/data/iris_svc_parameter.json @@ -1,23 +1,26 @@ {"hyperparameter": { "C": { "domain": "uniform", "data": "[0,20]", "type": "float" }, "gamma": { "domain": "uniform", "data": "[0.0001,20.0]", "type": "float" }, "kernel": { "domain": "categorical", "data": "[linear, sigmoid, poly, rbf]", "type": "str" } }, "settings": { "solver": { "max_iterations": "50", "use_plugin" : "hyperopt" + }, + "custom": { + "data_path": "C:/path/to/my/data" } }} \ No newline at end of file diff --git a/hyppopy/tests/data/iris_svc_parameter.xml b/hyppopy/tests/data/iris_svc_parameter.xml index 7d9670d..88cbc8c 100644 --- a/hyppopy/tests/data/iris_svc_parameter.xml +++ b/hyppopy/tests/data/iris_svc_parameter.xml @@ -1,25 +1,28 @@ uniform [0,20] float uniform [0.0001,20.0] float categorical [linear,sigmoid,poly,rbf] str 50 optunity + + C:/path/to/my/data + \ No newline at end of file diff --git a/hyppopy/tests/test_solver_factory.py b/hyppopy/tests/test_solver_factory.py index 177a27a..00c1b1c 100644 --- a/hyppopy/tests/test_solver_factory.py +++ b/hyppopy/tests/test_solver_factory.py @@ -1,104 +1,105 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import unittest from sklearn.svm import SVC from sklearn import datasets from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split from hyppopy.solverfactory import SolverFactory +from hyppopy.projectmanager import ProjectManager from hyppopy.globals import TESTDATA_DIR TESTPARAMFILE = os.path.join(TESTDATA_DIR, 'iris_svc_parameter') from hyppopy.deepdict import DeepDict class SolverFactoryTestSuite(unittest.TestCase): def setUp(self): iris = datasets.load_iris() X, X_test, y, y_test = train_test_split(iris.data, iris.target, test_size=0.1, random_state=42) self.my_IRIS_dta = [X, y] def test_solver_loading(self): names = SolverFactory.list_solver() self.assertTrue("hyperopt" in names) self.assertTrue("optunity" in names) def test_iris_solver_execution(self): def my_SVC_loss_func(data, params): clf = SVC(**params) return -cross_val_score(clf, data[0], data[1], cv=3).mean() + ProjectManager.read_config(TESTPARAMFILE + '.xml') solver = SolverFactory.get_solver('optunity') solver.set_data(self.my_IRIS_dta) - solver.read_parameter(TESTPARAMFILE + '.xml') solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() + ProjectManager.read_config(TESTPARAMFILE + '.json') solver = SolverFactory.get_solver('hyperopt') solver.set_data(self.my_IRIS_dta) - solver.read_parameter(TESTPARAMFILE + '.json') solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() def test_create_solver_from_settings_directly(self): def my_SVC_loss_func(data, params): clf = SVC(**params) return -cross_val_score(clf, data[0], data[1], cv=3).mean() solver = SolverFactory.from_settings(TESTPARAMFILE + '.xml') self.assertEqual(solver.name, "optunity") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() solver = SolverFactory.from_settings(TESTPARAMFILE + '.json') self.assertEqual(solver.name, "hyperopt") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() dd = DeepDict(TESTPARAMFILE + '.json') solver = SolverFactory.from_settings(dd) self.assertEqual(solver.name, "hyperopt") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() solver = SolverFactory.from_settings(dd.data) self.assertEqual(solver.name, "hyperopt") solver.set_data(self.my_IRIS_dta) solver.set_loss_function(my_SVC_loss_func) solver.run() solver.get_results() if __name__ == '__main__': unittest.main() diff --git a/hyppopy/workflows/datalaoder/__init__.py b/hyppopy/workflows/datalaoder/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/hyppopy/workflows/datalaoder/dataloader.py b/hyppopy/workflows/datalaoder/dataloader.py deleted file mode 100644 index 3d86ac4..0000000 --- a/hyppopy/workflows/datalaoder/dataloader.py +++ /dev/null @@ -1,34 +0,0 @@ -# DKFZ -# -# -# Copyright (c) German Cancer Research Center, -# Division of Medical and Biological Informatics. -# All rights reserved. -# -# This software is distributed WITHOUT ANY WARRANTY; without -# even the implied warranty of MERCHANTABILITY or FITNESS FOR -# A PARTICULAR PURPOSE. -# -# See LICENSE.txt or http://www.mitk.org for details. -# -# Author: Sven Wanner (s.wanner@dkfz.de) - -import abc - - -class DataLoader(object): - - def __init__(self): - self.data = None - - @abc.abstractmethod - def read(self, **kwargs): - raise NotImplementedError("the read method has to be implemented in classes derived from DataLoader") - - @abc.abstractmethod - def preprocess(self): - pass - - def get(self): - self.preprocess() - return self.data diff --git a/hyppopy/workflows/datalaoder/simpleloader.py b/hyppopy/workflows/datalaoder/simpleloader.py deleted file mode 100644 index 4a8c461..0000000 --- a/hyppopy/workflows/datalaoder/simpleloader.py +++ /dev/null @@ -1,41 +0,0 @@ -# DKFZ -# -# -# Copyright (c) German Cancer Research Center, -# Division of Medical and Biological Informatics. -# All rights reserved. -# -# This software is distributed WITHOUT ANY WARRANTY; without -# even the implied warranty of MERCHANTABILITY or FITNESS FOR -# A PARTICULAR PURPOSE. -# -# See LICENSE.txt or http://www.mitk.org for details. -# -# Author: Sven Wanner (s.wanner@dkfz.de) - -import os -import numpy as np -import pandas as pd - -from hyppopy.workflows.datalaoder.dataloader import DataLoader - - -class SimpleDataLoader(DataLoader): - - def read(self, **kwargs): - if kwargs['data_name'].endswith(".npy"): - if not kwargs['labels_name'].endswith(".npy"): - raise IOError("Expect both data_name and labels_name being of type .npy!") - self.data = [np.load(os.path.join(kwargs['path'], kwargs['data_name'])), np.load(os.path.join(kwargs['path'], kwargs['labels_name']))] - elif kwargs['data_name'].endswith(".csv"): - try: - dataset = pd.read_csv(os.path.join(kwargs['path'], kwargs['data_name'])) - y = dataset[kwargs['labels_name']].values - X = dataset.drop([kwargs['labels_name']], axis=1).values - self.data = [X, y] - except Exception as e: - print("Precondition violation, this usage case expects as data_name a " - "csv file and as label_name a name of a column in this csv table!") - else: - raise NotImplementedError("This combination of data_name and labels_name " - "does not yet exist, feel free to add it") diff --git a/hyppopy/workflows/workflowbase.py b/hyppopy/workflows/workflowbase.py index 61c879b..f66b878 100644 --- a/hyppopy/workflows/workflowbase.py +++ b/hyppopy/workflows/workflowbase.py @@ -1,61 +1,61 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.deepdict import DeepDict from hyppopy.solverfactory import SolverFactory from hyppopy.projectmanager import ProjectManager from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH import os import abc import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class WorkflowBase(object): def __init__(self): self._solver = SolverFactory.get_solver(ProjectManager.use_plugin) - self.solver.set_hyperparameters(ProjectManager.config['hyperparameter']) + self.solver.set_hyperparameters(ProjectManager.get_hyperparameter()) def run(self): self.setup() self.solver.set_loss_function(self.blackbox_function) self.solver.run() self.test() def get_results(self): return self.solver.get_results() @abc.abstractmethod def setup(self): raise NotImplementedError('the user has to implement this function') @abc.abstractmethod def blackbox_function(self): raise NotImplementedError('the user has to implement this function') @abc.abstractmethod def test(self): pass @property def solver(self): return self._solver