diff --git a/hyppopy/deepdict.py b/hyppopy/deepdict.py index 9a31c8a..a149797 100644 --- a/hyppopy/deepdict.py +++ b/hyppopy/deepdict.py @@ -1,386 +1,435 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import re import json import types import pprint import xmltodict from dicttoxml import dicttoxml from collections import OrderedDict import logging LOG = logging.getLogger('hyppopy') from hyppopy.globals import DEEPDICT_XML_ROOT def convert_ordered2std_dict(obj): """ Helper function converting an OrderedDict into a standard lib dict. :param obj: [OrderedDict] """ for key, value in obj.items(): if isinstance(value, OrderedDict): obj[key] = dict(obj[key]) convert_ordered2std_dict(obj[key]) def check_dir_existance(dirname): """ Helper function to check if a directory exists, creating it if not. :param dirname: [str] full path of the directory to check """ if not os.path.exists(dirname): os.mkdir(dirname) class DeepDict(object): """ The DeepDict class represents a nested dictionary with additional functionality compared to a standard lib dict. The data can be accessed and changed via a path-like access and dumped to or read from .json/.xml files. Initializing instances using defaults creates an empty DeepDict. Using in_data enables initializing the object instance with data, where in_data can be a dict, or a filepath to a json or xml file. Using path_sep the path separator can be changed; default data access via path looks like my_dd['target/section/path'], with path_sep='.' it becomes my_dd['target.section.path'] :param in_data: [dict] or [str], input dict or filename :param path_sep: [str] path separator character """ _data = None _sep = "/" def __init__(self, in_data=None, path_sep="/"): self.clear() self._sep = path_sep LOG.debug(f"path separator is: {self._sep}") if in_data is not None: if isinstance(in_data, str): self.from_file(in_data) elif isinstance(in_data, dict): self.data = in_data def __str__(self): """ Enables print output for class instances, printing the instance data dict using pretty print :return: [str] """ return pprint.pformat(self.data) def __eq__(self, other): """ Overloads the == operator comparing the instance data dictionaries for equality :param other: [DeepDict] rhs :return: [bool] """ return self.data == other.data def __getitem__(self, path): """ Overloads the return of the [] operator for data access. This enables accessing the DeepDict instance like so: my_dd['target/section/path'] or my_dd[['target','section','path']] :param path: [str] or [list(str)], the path to the target data structure level/content :return: [object] """ return DeepDict.get_from_path(self.data, path, self.sep) def __setitem__(self, path, value=None): """ Overloads the setter for the [] operator for data assignment.
:param path: [str] or [list(str)], the path to the target data structure level/content :param value: [object] rhs assignment object """ if isinstance(path, str): path = path.split(self.sep) if not isinstance(path, (list, tuple)): raise IOError("Input Error, expect list[str] type for path") if len(path) < 1: raise IOError("Input Error, missing section strings") if not path[0] in self._data.keys(): if value is not None and len(path) == 1: self._data[path[0]] = value else: self._data[path[0]] = {} tmp = self._data[path[0]] path.pop(0) while True: if len(path) == 0: break if path[0] not in tmp.keys(): if value is not None and len(path) == 1: tmp[path[0]] = value else: tmp[path[0]] = {} tmp = tmp[path[0]] else: tmp = tmp[path[0]] path.pop(0) def __len__(self): return len(self._data) + def items(self): + return self.data.items() + def clear(self): """ clears the instance data """ LOG.debug("clear()") self._data = {} def from_file(self, fname): """ Loads data from file. Currently implemented .json and .xml file reader. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.read_json(fname) elif fname.endswith(".xml"): self.read_xml(fname) else: LOG.error("Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def read_json(self, fname): """ Read json file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): raise IOError(f"File {fname} not found!") LOG.debug(f"read_json({fname})") try: with open(fname, "r") as read_file: self._data = json.load(read_file) DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: LOG.error(f"Error while reading json file {fname} or while converting types") raise IOError(f"Error while reading json file {fname} or while converting types") def read_xml(self, fname): """ Read xml file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): raise IOError(f"File {fname} not found!") LOG.debug(f"read_xml({fname})") try: with open(fname, "r") as read_file: xml = "".join(read_file.readlines()) self._data = xmltodict.parse(xml, attr_prefix='') DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: msg = f"Error while reading xml file {fname} or while converting types" LOG.error(msg) raise IOError(msg) # if written with DeepDict, the xml contains a root node called # deepdict which should be removed for consistency reasons if DEEPDICT_XML_ROOT in self._data.keys(): self._data = self._data[DEEPDICT_XML_ROOT] self._data = dict(self.data) # convert the ordered dict structure to a standard dict for consistency reasons convert_ordered2std_dict(self.data) def to_file(self, fname): """ Write to file, type is determined by checking the filename ending. Currently implemented is writing to json and to xml. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.write_json(fname) elif fname.endswith(".xml"): self.write_xml(fname) else: LOG.error("Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def write_json(self, fname): """ Dump data to json file.
:param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) try: LOG.debug(f"write_json({fname})") with open(fname, "w") as write_file: json.dump(self.data, write_file) except Exception as e: LOG.error(f"Failed dumping to json file: {fname}") raise e def write_xml(self, fname): """ Dump data to xml file. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) xml = dicttoxml(self.data, custom_root=DEEPDICT_XML_ROOT, attr_type=False) LOG.debug(f"write_xml({fname})") try: with open(fname, "w") as write_file: write_file.write(xml.decode("utf-8")) except Exception as e: LOG.error(f"Failed dumping to xml file: {fname}") raise e def has_section(self, section): return DeepDict.has_key(self.data, section) @staticmethod def get_from_path(data, path, sep="/"): """ Implements a nested dict access via a path like string like so path='target/section/path' which is equivalent to my_dict['target']['section']['path']. :param data: [dict] input dictionary :param path: [str] pathlike string :param sep: [str] path separator, default='/' :return: [object] """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for data") raise IOError("Input Error, expect dict type for data") if isinstance(path, str): path = path.split(sep) if not isinstance(path, (list, tuple)): LOG.error(f"Input Error, expect list[str] type for path: {path}") raise IOError("Input Error, expect list[str] type for path") if not DeepDict.has_key(data, path[-1]): LOG.error(f"Input Error, section {path[-1]} does not exist in dictionary") raise IOError(f"Input Error, section {path[-1]} does not exist in dictionary") try: for k in path: data = data[k] except Exception as e: LOG.error(f"Failed retrieving data from path {path} due to {e}") raise LookupError(f"Failed retrieving data from path {path} due to {e}") return data @staticmethod def has_key(data, section, already_found=False): """ Checks if input dictionary has a key called section. The already_found parameter is for internal recursion checks. :param data: [dict] input dictionary :param section: [str] key string to search for :param already_found: recursion criteria check :return: [bool] section found """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(section, str): LOG.error(f"Input Error, expect str type for section {section}") raise IOError(f"Input Error, expect str type for section {section}") if already_found: return True found = False for key, value in data.items(): if key == section: found = True if isinstance(value, dict): found = DeepDict.has_key(data[key], section, found) return found @staticmethod def value_traverse(data, callback=None): """ Dictionary filter function, walks through the input dict (data) calling the callback function for each value. The callback function's return value is assigned to the corresponding dict value.
:param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.value_traverse(data[key], callback) else: data[key] = callback(value) + def transfer_attrs(self, cls, target_section): + def set_attr(item): + setattr(cls, item[0], item[1]) + DeepDict.sectionconstraint_item_traverse(self.data, target_section, callback=set_attr, section=None) + + @staticmethod + def sectionconstraint_item_traverse(data, target_section, callback=None, section=None): + """ + Dictionary filter function, walks through the input dict (data) calling the callback function for each item. + The callback function is then called with the key value pair as tuple input, but only for items inside the target section. + :param data: [dict] input dictionary + :param target_section: [str] name of the section the callback is restricted to + :param callback: + :param section: [str] current section name, used internally during recursion + """ + if not isinstance(data, dict): + LOG.error("Input Error, expect dict type for obj") + raise IOError("Input Error, expect dict type for obj") + if not isinstance(callback, types.FunctionType): + LOG.error("Input Error, expect function type for callback") + raise IOError("Input Error, expect function type for callback") + for key, value in data.items(): + if isinstance(value, dict): + DeepDict.sectionconstraint_item_traverse(data[key], target_section, callback, key) + else: + if target_section == section: + callback((key, value)) + + @staticmethod + def item_traverse(data, callback=None): + """ + Dictionary filter function, walks through the input dict (data) calling the callback function for each item. + The callback function is then called with the key value pair as tuple input.
+ :param data: [dict] input dictionary + :param callback: + """ + if not isinstance(data, dict): + LOG.error("Input Error, expect dict type for obj") + raise IOError("Input Error, expect dict type for obj") + if not isinstance(callback, types.FunctionType): + LOG.error("Input Error, expect function type for callback") + raise IOError("Input Error, expect function type for callback") + for key, value in data.items(): + if isinstance(value, dict): + DeepDict.item_traverse(data[key], callback) + else: + callback((key, value)) + @staticmethod def parse_type(string): """ Type convert input string to float, int, list, tuple or string :param string: [str] input string :return: [T] converted output """ try: a = float(string) try: b = int(string) except ValueError: return float(string) if a == b: return b return a except ValueError: if string.startswith("[") and string.endswith("]"): string = re.sub(' ', '', string) elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return li elif string.startswith("(") and string.endswith(")"): elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return tuple(li) return string @property def data(self): return self._data @data.setter def data(self, value): if not isinstance(value, dict): LOG.error(f"Input Error, expect dict type for value, but got {type(value)}") raise IOError(f"Input Error, expect dict type for value, but got {type(value)}") self.clear() self._data = value @property def sep(self): return self._sep @sep.setter def sep(self, value): if not isinstance(value, str): LOG.error(f"Input Error, expect str type for value, but got {type(value)}") raise IOError(f"Input Error, expect str type for value, but got {type(value)}") self._sep = value diff --git a/hyppopy/plugins/hyperopt_solver_plugin.py b/hyppopy/plugins/hyperopt_solver_plugin.py index c94ae6a..7cc784f 100644 --- a/hyppopy/plugins/hyperopt_solver_plugin.py +++ b/hyppopy/plugins/hyperopt_solver_plugin.py @@ -1,70 +1,71 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details.
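To make the access semantics above concrete, here is a minimal usage sketch of the DeepDict API shown in this diff; the dictionary content and the Target class are made up for illustration.

from hyppopy.deepdict import DeepDict

dd = DeepDict(path_sep="/")
dd["settings/solver/max_iterations"] = 50   # intermediate sections are created on the fly
dd["settings/solver/use_plugin"] = "optunity"
assert dd["settings/solver/max_iterations"] == 50
assert dd[["settings", "solver", "use_plugin"]] == "optunity"

class Target(object):   # hypothetical target object
    pass

target = Target()
dd.transfer_attrs(target, "solver")   # copies max_iterations and use_plugin onto target
assert target.max_iterations == 50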
# # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat from hyperopt import fmin, tpe, hp, STATUS_OK, STATUS_FAIL, Trials from yapsy.IPlugin import IPlugin +from hyppopy.projectmanager import ProjectManager from hyppopy.solverpluginbase import SolverPluginBase class hyperopt_Solver(SolverPluginBase, IPlugin): trials = None best = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") def loss_function(self, params): try: loss = self.loss(self.data, params) status = STATUS_OK except Exception as e: LOG.error(f"execution of self.loss(self.data, params) failed due to:\n {e}") status = STATUS_FAIL loss = 1e9 return {'loss': loss, 'status': status} def execute_solver(self, parameter): LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n") self.trials = Trials() try: self.best = fmin(fn=self.loss_function, space=parameter, algo=tpe.suggest, - max_evals=self.settings.max_iterations, + max_evals=ProjectManager.max_iterations, trials=self.trials) except Exception as e: msg = f"internal error in hyperopt.fmin occurred. {e}" LOG.error(msg) raise BrokenPipeError(msg) def convert_results(self): txt = "" solution = dict([(k, v) for k, v in self.best.items() if v is not None]) txt += 'Solution Hyperopt Plugin\n========\n' txt += "\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items())) txt += "\n" return txt diff --git a/hyppopy/plugins/optunity_solver_plugin.py b/hyppopy/plugins/optunity_solver_plugin.py index c92ab52..c5e47f6 100644 --- a/hyppopy/plugins/optunity_solver_plugin.py +++ b/hyppopy/plugins/optunity_solver_plugin.py @@ -1,72 +1,71 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat import optunity from yapsy.IPlugin import IPlugin +from hyppopy.projectmanager import ProjectManager from hyppopy.solverpluginbase import SolverPluginBase class optunity_Solver(SolverPluginBase, IPlugin): solver_info = None trials = None best = None status = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") def loss_function(self, **params): try: loss = self.loss(self.data, params) self.status.append('ok') return loss except Exception as e: LOG.error(f"computing loss failed due to:\n {e}") self.status.append('fail') return 1e9 def execute_solver(self, parameter): LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n") self.status = [] try: self.best, self.trials, self.solver_info = optunity.minimize_structured(f=self.loss_function, - num_evals=self.settings.max_iterations, + num_evals=ProjectManager.max_iterations, search_space=parameter) except Exception as e: LOG.error(f"internal error in optunity.minimize_structured occurred. {e}") raise BrokenPipeError(f"internal error in optunity.minimize_structured occurred.
{e}") def convert_results(self): solution = dict([(k, v) for k, v in self.best.items() if v is not None]) txt = "" txt += 'Solution Optunity Plugin\n========\n' txt += "\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items())) txt += f"\nSolver used: {self.solver_info['solver_name']}" txt += f"\nOptimum: {self.trials.optimum}" txt += f"\nIterations used: {self.trials.stats['num_evals']}" txt += f"\nDuration: {self.trials.stats['time']} s\n" return txt diff --git a/hyppopy/projectmanager.py b/hyppopy/projectmanager.py new file mode 100644 index 0000000..24eed29 --- /dev/null +++ b/hyppopy/projectmanager.py @@ -0,0 +1,55 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + +from hyppopy.singleton import * +from hyppopy.deepdict import DeepDict +from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH + +import os +import logging +from hyppopy.globals import DEBUGLEVEL +LOG = logging.getLogger(os.path.basename(__file__)) +LOG.setLevel(DEBUGLEVEL) + + +@singleton_object +class ProjectManager(metaclass=Singleton): + + def __init__(self): + self.configfilename = None + self.config = None + + def test_config(self): + #TODO test the config structure to fullfill the needs, throwing useful error is not + return True + + def read_config(self, configfile): + self.configfilename = configfile + self.config = DeepDict(configfile) + if not self.test_config(): + self.configfilename = None + self.config = None + return False + + try: + self.config.transfer_attrs(self, SETTINGSCUSTOMPATH.split("/")[-1]) + self.config.transfer_attrs(self, SETTINGSSOLVERPATH.split("/")[-1]) + except Exception as e: + msg = f"transfering custom section as class attributes failed, " \ + f"is the config path to your custom section correct? {SETTINGSCUSTOMPATH}. Exception {e}" + LOG.error(msg) + raise LookupError(msg) + + return True diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py index 07e0d37..3d8d49a 100644 --- a/hyppopy/settingspluginbase.py +++ b/hyppopy/settingspluginbase.py @@ -1,89 +1,77 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
# # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from hyppopy.globals import SETTINGSSOLVERPATH, SETTINGSCUSTOMPATH from hyppopy.deepdict import DeepDict class SettingsPluginBase(object): _data = None _name = None def __init__(self): - self._data = DeepDict() + self._data = {} @abc.abstractmethod def convert_parameter(self, data): raise NotImplementedError('users must define convert_parameter to use this base class') def get_hyperparameter(self): - return self.convert_parameter(self.data["hyperparameter"]) + return self.convert_parameter(self.data) def set(self, data): self.data.clear() self.data = data def read(self, fname): self.data = DeepDict(fname).data def write(self, fname): DeepDict(self.data).to_file(fname) - def set_attributes(self, cls): - if self.data.has_section(SETTINGSSOLVERPATH.split('/')[-1]): - attrs_sec = self.data[SETTINGSSOLVERPATH] - for key, value in attrs_sec.items(): - setattr(cls, key, value) - if self.data.has_section(SETTINGSCUSTOMPATH.split('/')[-1]): - attrs_sec = self.data[SETTINGSCUSTOMPATH] - for key, value in attrs_sec.items(): - setattr(cls, key, value) - @property def data(self): return self._data @data.setter def data(self, value): if isinstance(value, dict): - self._data.data = value - elif isinstance(value, DeepDict): self._data = value + elif isinstance(value, DeepDict): + self._data = value.data else: raise IOError(f"unexpected input type({type(value)}) for data, needs to be of type dict or DeepDict!") @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solver.py b/hyppopy/solver.py index 78d3917..1ae9740 100644 --- a/hyppopy/solver.py +++ b/hyppopy/solver.py @@ -1,84 +1,77 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details.
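Under the changed contract a settings plugin now stores a plain dict and convert_parameter receives that dict directly; a minimal subclass might look like the following sketch, where the plugin name and the pass-through conversion are illustrative only.

from hyppopy.settingspluginbase import SettingsPluginBase

class demo_Settings(SettingsPluginBase):   # hypothetical plugin, follows the libname_Settings convention
    def convert_parameter(self, data):
        # illustrative: a real plugin converts the dict into its solver's space format
        return data

settings = demo_Settings()
settings.set({"C": {"domain": "uniform", "data": "[0,20]", "type": "float"}})
space = settings.get_hyperparameter()   # == the dict passed to set()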
# # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Solver(object): _name = None _solver_plugin = None _settings_plugin = None def __init__(self): pass def set_data(self, data): self.solver.set_data(data) - def set_parameters(self, params): + def set_hyperparameters(self, params): self.settings.set(params) - self.settings.set_attributes(self.solver) - self.settings.set_attributes(self.settings) - - def read_parameter(self, fname): - self.settings.read(fname) - self.settings.set_attributes(self.solver) - self.settings.set_attributes(self.settings) def set_loss_function(self, loss_func): self.solver.set_loss_function(loss_func) def run(self): self.solver.settings = self.settings self.solver.run() def get_results(self): return self.solver.get_results() @property def is_ready(self): return self.solver is not None and self.settings is not None @property def solver(self): return self._solver_plugin @solver.setter def solver(self, value): self._solver_plugin = value @property def settings(self): return self._settings_plugin @settings.setter def settings(self, value): self._settings_plugin = value @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solverfactory.py b/hyppopy/solverfactory.py index cfbe1ed..697c33c 100644 --- a/hyppopy/solverfactory.py +++ b/hyppopy/solverfactory.py @@ -1,156 +1,156 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from yapsy.PluginManager import PluginManager from hyppopy.globals import PLUGIN_DEFAULT_DIR from hyppopy.deepdict import DeepDict from hyppopy.solver import Solver from hyppopy.singleton import * import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) @singleton_object class SolverFactory(metaclass=Singleton): """ This class is responsible for grabbing all plugins from the plugin folder and arranging them into Solver class instances. These Solver class instances can be requested from the factory via the get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using SolverFactory(), the consequences will be horrific. Instead use it like a class having static functions only, SolverFactory.method().
""" _plugin_dirs = [] _plugins = {} def __init__(self): print("Solverfactory: I'am alive!") self.reset() self.load_plugins() LOG.debug("Solverfactory initialized") def load_plugins(self): """ Load plugin modules from plugin paths """ LOG.debug("load_plugins()") manager = PluginManager() LOG.debug(f"setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) manager.setPluginPlaces(self._plugin_dirs) manager.collectPlugins() for plugin in manager.getAllPlugins(): name_elements = plugin.plugin_object.__class__.__name__.split("_") LOG.debug("found plugin " + " ".join(map(str, name_elements))) if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements): LOG.error(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") raise NameError(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") if name_elements[0] not in self._plugins.keys(): self._plugins[name_elements[0]] = Solver() self._plugins[name_elements[0]].name = name_elements[0] if name_elements[1] == "Solver": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].solver = obj LOG.info(f"plugin: {name_elements[0]} Solver loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"Failed to instanciate class {plugin.plugin_object.__class__.__name__}") elif name_elements[1] == "Settings": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].settings = obj LOG.info(f"plugin: {name_elements[0]} ParameterSpace loaded") except Exception as e: LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") raise ImportError(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") else: LOG.error(f"failed loading plugin {name_elements[0]}, please check if naming conventions are kept!") raise IOError(f"failed loading plugin {name_elements[0]}!, please check if naming conventions are kept!") if len(self._plugins) == 0: msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!" 
LOG.error(msg) raise IOError(msg) def reset(self): """ Reset solver factory """ LOG.debug("reset()") self._plugins = {} self._plugin_dirs = [] self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR)) def add_plugin_dir(self, dir): """ Add plugin directory """ LOG.debug(f"add_plugin_dir({dir})") self._plugin_dirs.append(dir) def list_solver(self): """ list all solvers available :return: [list(str)] """ return list(self._plugins.keys()) def from_settings(self, settings): if isinstance(settings, dict): tmp = DeepDict() tmp.data = settings settings = tmp elif isinstance(settings, str): if not os.path.isfile(settings): LOG.warning(f"input error, file {settings} not found!") settings = DeepDict(settings) if isinstance(settings, DeepDict): if settings.has_section("use_plugin"): try: use_plugin = settings["settings/solver/use_plugin"] except Exception as e: LOG.warning("wrong settings path for use_plugin option detected, expecting the path settings/solver/use_plugin!") return None solver = self.get_solver(use_plugin) - solver.set_parameters(settings) + solver.set_hyperparameters(settings['hyperparameter']) return solver LOG.warning("failed to choose a solver, either the config file is missing the section settings/solver/use_plugin, or there might be a typo") else: msg = "unknown input error, expected DeepDict, dict or filename!" LOG.error(msg) raise IOError(msg) return None def get_solver(self, name): """ returns a solver by name tag :param name: [str] solver name :return: [Solver] instance """ if not isinstance(name, str): msg = f"Invalid input, str type expected for name, got {type(name)} instead" LOG.error(msg) raise IOError(msg) if name not in self.list_solver(): msg = f"failed solver request, a solver called {name} is not available, " \ f"check for typo or if your plugin failed while loading!" LOG.error(msg) raise LookupError(msg) LOG.debug(f"get_solver({name})") return self._plugins[name] diff --git a/hyppopy/solverpluginbase.py b/hyppopy/solverpluginbase.py index 63a16d7..e880e80 100644 --- a/hyppopy/solverpluginbase.py +++ b/hyppopy/solverpluginbase.py @@ -1,86 +1,84 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details.
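The factory is meant to be used like a namespace of static calls; a minimal sketch follows (which solver names are available depends on the plugins that actually loaded).

from hyppopy.solverfactory import SolverFactory

print(SolverFactory.list_solver())            # e.g. ['hyperopt', 'optunity']
solver = SolverFactory.get_solver('optunity') # raises LookupError for unknown names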
# # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import logging from hyppopy.globals import DEBUGLEVEL from hyppopy.settingspluginbase import SettingsPluginBase LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class SolverPluginBase(object): data = None loss = None _settings = None _name = None def __init__(self): pass @abc.abstractmethod def loss_function(self, params): raise NotImplementedError('users must define loss_func to use this base class') @abc.abstractmethod def execute_solver(self): raise NotImplementedError('users must define execute_solver to use this base class') @abc.abstractmethod def convert_results(self): raise NotImplementedError('users must define convert_results to use this base class') def set_data(self, data): self.data = data def set_loss_function(self, func): self.loss = func def get_results(self): return self.convert_results() def run(self): self.execute_solver(self.settings.get_hyperparameter()) @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): msg = f"Invalid input, str type expected for value, got {type(value)} instead" LOG.error(msg) raise IOError(msg) self._name = value @property def settings(self): return self._settings @settings.setter def settings(self, value): if not isinstance(value, SettingsPluginBase): msg = f"Invalid input, SettingsPluginBase type expected for value, got {type(value)} instead" LOG.error(msg) raise IOError(msg) self._settings = value diff --git a/hyppopy/tests/data/Iris/rf_config.json b/hyppopy/tests/data/Iris/rf_config.json index 869c319..443e510 100644 --- a/hyppopy/tests/data/Iris/rf_config.json +++ b/hyppopy/tests/data/Iris/rf_config.json @@ -1,42 +1,43 @@ {"hyperparameter": { "n_estimators": { "domain": "uniform", "data": "[3,500]", "type": "int" }, "criterion": { "domain": "categorical", "data": "[gini,entropy]", "type": "str" }, "max_depth": { "domain": "uniform", "data": "[3, 50]", "type": "int" }, "min_samples_split": { "domain": "uniform", "data": "[0.0001,1]", "type": "float" }, "min_samples_leaf": { "domain": "uniform", "data": "[0.0001,0.5]", "type": "float" }, "max_features": { "domain": "categorical", "data": "[auto,sqrt,log2]", "type": "str" } }, "settings": { "solver": { "max_iterations": "3", "use_plugin" : "optunity" }, "custom": { + "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris", "data_name": "train_data.npy", "labels_name": "train_labels.npy" } }} \ No newline at end of file diff --git a/hyppopy/tests/data/Iris/rf_config.xml b/hyppopy/tests/data/Iris/rf_config.xml index 925d164..b60530d 100644 --- a/hyppopy/tests/data/Iris/rf_config.xml +++ b/hyppopy/tests/data/Iris/rf_config.xml @@ -1,44 +1,45 @@ uniform [3,200] int categorical [gini,entropy] str uniform [3, 50] int uniform [0.0001,1] float uniform [0.0001,0.5] float categorical [auto,sqrt,log2] str 3 optunity + D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris train_data.npy train_labels.npy \ No newline at end of file diff --git a/hyppopy/tests/data/Iris/svc_config.json b/hyppopy/tests/data/Iris/svc_config.json index 0628f97..59fb433 100644 --- a/hyppopy/tests/data/Iris/svc_config.json +++ b/hyppopy/tests/data/Iris/svc_config.json @@ -1,32 +1,33 @@ {"hyperparameter": { "C": { "domain": "uniform", "data": "[0,20]", "type": "float" }, "gamma": { "domain": "uniform", "data": "[0.0001,20.0]", "type": "float" }, "kernel": { "domain": "categorical", "data": "[linear, sigmoid, poly, rbf]", "type": "str" }, "decision_function_shape": { "domain": 
"categorical", "data": "[ovo,ovr]", "type": "str" } }, "settings": { "solver": { "max_iterations": "3", "use_plugin" : "optunity" }, "custom": { + "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris", "data_name": "train_data.npy", "labels_name": "train_labels.npy" } }} \ No newline at end of file diff --git a/hyppopy/tests/data/Iris/svc_config.xml b/hyppopy/tests/data/Iris/svc_config.xml index fb4f50b..f8ab4e3 100644 --- a/hyppopy/tests/data/Iris/svc_config.xml +++ b/hyppopy/tests/data/Iris/svc_config.xml @@ -1,34 +1,35 @@ uniform [0,20] float uniform [0.0001,20.0] float categorical [linear,sigmoid,poly,rbf] str categorical [ovo,ovr] str 3 hyperopt + D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris train_data.npy train_labels.npy \ No newline at end of file diff --git a/hyppopy/tests/data/Titanic/rf_config.json b/hyppopy/tests/data/Titanic/rf_config.json index 7993c78..a637f35 100644 --- a/hyppopy/tests/data/Titanic/rf_config.json +++ b/hyppopy/tests/data/Titanic/rf_config.json @@ -1,27 +1,28 @@ {"hyperparameter": { "n_estimators": { "domain": "uniform", "data": "[3,500]", "type": "int" }, "criterion": { "domain": "categorical", "data": "[gini,entropy]", "type": "str" }, "max_depth": { "domain": "uniform", "data": "[3, 50]", "type": "int" } }, "settings": { "solver": { "max_iterations": "3", "use_plugin" : "optunity" }, "custom": { + "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic", "data_name": "train_cleaned.csv", "labels_name": "Survived" } }} \ No newline at end of file diff --git a/hyppopy/tests/data/Titanic/rf_config.xml b/hyppopy/tests/data/Titanic/rf_config.xml index 5dd0797..fbfa828 100644 --- a/hyppopy/tests/data/Titanic/rf_config.xml +++ b/hyppopy/tests/data/Titanic/rf_config.xml @@ -1,29 +1,30 @@ uniform [3,200] int categorical [gini,entropy] str uniform [3, 50] int 3 optunity + D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic train_cleaned.csv Survived \ No newline at end of file diff --git a/hyppopy/tests/data/Titanic/svc_config.json b/hyppopy/tests/data/Titanic/svc_config.json index 3291024..4066bb6 100644 --- a/hyppopy/tests/data/Titanic/svc_config.json +++ b/hyppopy/tests/data/Titanic/svc_config.json @@ -1,32 +1,33 @@ {"hyperparameter": { "C": { "domain": "uniform", "data": "[0,20]", "type": "float" }, "gamma": { "domain": "uniform", "data": "[0.0001,20.0]", "type": "float" }, "kernel": { "domain": "categorical", "data": "[linear, sigmoid, poly, rbf]", "type": "str" }, "decision_function_shape": { "domain": "categorical", "data": "[ovo,ovr]", "type": "str" } }, "settings": { "solver": { "max_iterations": "3", "use_plugin" : "hyperopt" }, "custom": { + "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic", "data_name": "train_cleaned.csv", "labels_name": "Survived" } }} \ No newline at end of file diff --git a/hyppopy/tests/data/Titanic/svc_config.xml b/hyppopy/tests/data/Titanic/svc_config.xml index b26c191..1107c4a 100644 --- a/hyppopy/tests/data/Titanic/svc_config.xml +++ b/hyppopy/tests/data/Titanic/svc_config.xml @@ -1,34 +1,35 @@ uniform [0,20] float uniform [0.0001,20.0] float categorical [linear,sigmoid,poly,rbf] str categorical [ovo,ovr] str 3 optunity + D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic train_cleaned.csv Survived \ No newline at end of file diff --git a/hyppopy/tests/test_deepdict.py b/hyppopy/tests/test_deepdict.py index 0868cf5..fc5efe8 100644 --- a/hyppopy/tests/test_deepdict.py +++ b/hyppopy/tests/test_deepdict.py @@ -1,153 +1,163 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright 
(c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import unittest from hyppopy.deepdict import DeepDict DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") class DeepDictTestSuite(unittest.TestCase): def setUp(self): self.test_data = { 'widget': { 'debug': 'on', 'image': {'alignment': 'center', 'hOffset': 250, 'name': 'sun1', 'src': 'Images/Sun.png', 'vOffset': 250}, 'text': {'alignment': 'center', 'data': 'Click Here', 'hOffset': 250, 'name': 'text1', 'onMouseUp': 'sun1.opacity = (sun1.opacity / 100) * 90;', 'size': 36, 'style': 'bold', 'vOffset': 100}, 'window': {'height': 500, 'name': 'main_window', 'title': 'Sample Konfabulator Widget', 'width': 500} } } self.test_data2 = {"test": { "section": { "var1": 100, "var2": 200 } }} def test_fileIO(self): dd_json = DeepDict(os.path.join(DATA_PATH, 'test_json.json')) dd_xml = DeepDict(os.path.join(DATA_PATH, 'test_xml.xml')) dd_dict = DeepDict(self.test_data) self.assertTrue(list(self.test_data.keys())[0] == list(dd_json.data.keys())[0]) self.assertTrue(list(self.test_data.keys())[0] == list(dd_xml.data.keys())[0]) self.assertTrue(list(self.test_data.keys())[0] == list(dd_dict.data.keys())[0]) for key in self.test_data['widget'].keys(): self.assertTrue(self.test_data['widget'][key] == dd_json.data['widget'][key]) self.assertTrue(self.test_data['widget'][key] == dd_xml.data['widget'][key]) self.assertTrue(self.test_data['widget'][key] == dd_dict.data['widget'][key]) for key in self.test_data['widget'].keys(): if key == 'debug': self.assertTrue(dd_json.data['widget']["debug"] == "on") self.assertTrue(dd_xml.data['widget']["debug"] == "on") self.assertTrue(dd_dict.data['widget']["debug"] == "on") else: for key2, value2 in self.test_data['widget'][key].items(): self.assertTrue(value2 == dd_json.data['widget'][key][key2]) self.assertTrue(value2 == dd_xml.data['widget'][key][key2]) self.assertTrue(value2 == dd_dict.data['widget'][key][key2]) dd_dict.to_file(os.path.join(DATA_PATH, 'write_to_json_test.json')) dd_dict.to_file(os.path.join(DATA_PATH, 'write_to_xml_test.xml')) self.assertTrue(os.path.isfile(os.path.join(DATA_PATH, 'write_to_json_test.json'))) self.assertTrue(os.path.isfile(os.path.join(DATA_PATH, 'write_to_xml_test.xml'))) dd_json = DeepDict(os.path.join(DATA_PATH, 'write_to_json_test.json')) dd_xml = DeepDict(os.path.join(DATA_PATH, 'write_to_xml_test.xml')) self.assertTrue(dd_json == dd_dict) self.assertTrue(dd_xml == dd_dict) try: os.remove(os.path.join(DATA_PATH, 'write_to_json_test.json')) os.remove(os.path.join(DATA_PATH, 'write_to_xml_test.xml')) except Exception as e: print(e) print("Warning: Failed to delete temporary data during tests!") def test_has_section(self): dd = DeepDict(self.test_data) self.assertTrue(dd.has_section('hOffset')) self.assertTrue(dd.has_section('window')) self.assertTrue(dd.has_section('widget')) self.assertTrue(dd.has_section('style')) self.assertTrue(dd.has_section('window')) self.assertTrue(dd.has_section('title')) self.assertFalse(dd.has_section('notasection')) def test_data_access(self): dd = DeepDict(self.test_data) self.assertEqual(dd['widget/window/height'], 500) self.assertEqual(dd['widget/image/name'], 'sun1') self.assertTrue(isinstance(dd['widget/window'], 
dict)) self.assertEqual(len(dd['widget/window']), 4) dd = DeepDict(path_sep=".") dd.data = self.test_data self.assertEqual(dd['widget.window.height'], 500) self.assertEqual(dd['widget.image.name'], 'sun1') self.assertTrue(isinstance(dd['widget.window'], dict)) self.assertEqual(len(dd['widget.window']), 4) def test_data_adding(self): dd = DeepDict() dd["test/section/var1"] = 100 dd["test/section/var2"] = 200 self.assertTrue(dd.data == self.test_data2) dd = DeepDict() dd["test"] = {} dd["test/section"] = {} dd["test/section/var1"] = 100 dd["test/section/var2"] = 200 self.assertTrue(dd.data == self.test_data2) def test_sample_space(self): dd = DeepDict(os.path.join(DATA_PATH, 'test_paramset.json')) self.assertEqual(len(dd[['parameter', 'activation', 'data']]), 4) self.assertEqual(dd['parameter/activation/data'], ['ReLU', 'tanh', 'sigm', 'ELU']) self.assertTrue(isinstance(dd['parameter/activation/data'], list)) self.assertTrue(isinstance(dd['parameter/activation/data'][0], str)) self.assertEqual(dd['parameter/layerdepth/data'], [3, 20]) self.assertTrue(isinstance(dd['parameter/layerdepth/data'], list)) self.assertTrue(isinstance(dd['parameter/layerdepth/data'][0], int)) self.assertTrue(isinstance(dd['parameter/learningrate/data'][0], float)) self.assertEqual(dd['parameter/learningrate/data'][0], 1e-5) self.assertEqual(dd['parameter/learningrate/data'][1], 10.0) def test_len(self): dd = DeepDict(os.path.join(DATA_PATH, 'test_paramset.json')) self.assertEqual(len(dd), 1) + def test_setattr(self): + dd = DeepDict(os.path.join(DATA_PATH, 'iris_svc_parameter.xml')) + + class Foo(object): + def __init__(self): + pass + foo = Foo() + dd.transfer_attrs(foo, 'solver') + self.assertEqual(foo.max_iterations, 50) + self.assertEqual(foo.use_plugin, 'optunity') + + if __name__ == '__main__': unittest.main() diff --git a/hyppopy/tests/test_projectmanager.py b/hyppopy/tests/test_projectmanager.py new file mode 100644 index 0000000..0e52fe1 --- /dev/null +++ b/hyppopy/tests/test_projectmanager.py @@ -0,0 +1,38 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import unittest +from hyppopy.projectmanager import ProjectManager + + +DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") + + +class ProjectManagerTestSuite(unittest.TestCase): + + def setUp(self): + pass + + def test_attr_transfer(self): + ProjectManager.read_config(os.path.join(DATA_PATH, *('Titanic', 'rf_config.xml'))) + self.assertEqual(ProjectManager.data_name, 'train_cleaned.csv') + self.assertEqual(ProjectManager.labels_name, 'Survived') + self.assertEqual(ProjectManager.max_iterations, 3) + self.assertEqual(ProjectManager.use_plugin, 'optunity') + + +if __name__ == '__main__': + unittest.main() diff --git a/hyppopy/tests/test_workflows.py b/hyppopy/tests/test_workflows.py index 8071a1a..6b2c848 100644 --- a/hyppopy/tests/test_workflows.py +++ b/hyppopy/tests/test_workflows.py @@ -1,130 +1,96 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved.
# # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import unittest from hyppopy.globals import TESTDATA_DIR IRIS_DATA = os.path.join(TESTDATA_DIR, 'Iris') TITANIC_DATA = os.path.join(TESTDATA_DIR, 'Titanic') +from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.svc_usecase.svc_usecase import svc_usecase from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase -class Args(object): - - def __init__(self): - pass - - def set_arg(self, name, value): - setattr(self, name, value) - - class WorkflowTestSuite(unittest.TestCase): def setUp(self): self.results = [] def test_workflow_svc_on_iris_from_xml(self): - svc_args_xml = Args() - svc_args_xml.set_arg('plugin', '') - svc_args_xml.set_arg('data', IRIS_DATA) - svc_args_xml.set_arg('config', os.path.join(IRIS_DATA, 'svc_config.xml')) - uc = svc_usecase(svc_args_xml) + ProjectManager.read_config(os.path.join(IRIS_DATA, 'svc_config.xml')) + uc = svc_usecase() uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_iris_from_xml(self): - rf_args_xml = Args() - rf_args_xml.set_arg('plugin', '') - rf_args_xml.set_arg('data', IRIS_DATA) - rf_args_xml.set_arg('config', os.path.join(IRIS_DATA, 'rf_config.xml')) - uc = svc_usecase(rf_args_xml) + ProjectManager.read_config(os.path.join(IRIS_DATA, 'rf_config.xml')) + uc = randomforest_usecase() uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_svc_on_iris_from_json(self): - svc_args_json = Args() - svc_args_json.set_arg('plugin', '') - svc_args_json.set_arg('data', IRIS_DATA) - svc_args_json.set_arg('config', os.path.join(IRIS_DATA, 'svc_config.json')) - uc = svc_usecase(svc_args_json) + ProjectManager.read_config(os.path.join(IRIS_DATA, 'svc_config.json')) + uc = svc_usecase() uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_iris_from_json(self): - rf_args_json = Args() - rf_args_json.set_arg('plugin', '') - rf_args_json.set_arg('data', IRIS_DATA) - rf_args_json.set_arg('config', os.path.join(IRIS_DATA, 'rf_config.json')) - uc = randomforest_usecase(rf_args_json) + ProjectManager.read_config(os.path.join(IRIS_DATA, 'rf_config.json')) + uc = randomforest_usecase() uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) - def test_workflow_svc_on_titanic_from_xml(self): - svc_args_xml = Args() - svc_args_xml.set_arg('plugin', '') - svc_args_xml.set_arg('data', TITANIC_DATA) - svc_args_xml.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.xml')) - uc = svc_usecase(svc_args_xml) - uc.run() - self.results.append(uc.get_results()) - self.assertTrue(uc.get_results().find("Solution") != -1) + # def test_workflow_svc_on_titanic_from_xml(self): + # ProjectManager.read_config(os.path.join(TITANIC_DATA, 'svc_config.xml')) + # uc = svc_usecase() + # uc.run() + # self.results.append(uc.get_results()) + # self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_titanic_from_xml(self): - rf_args_xml = Args() - rf_args_xml.set_arg('plugin', '') - rf_args_xml.set_arg('data', TITANIC_DATA) - rf_args_xml.set_arg('config', os.path.join(TITANIC_DATA, 'rf_config.xml')) - uc =
randomforest_usecase(rf_args_xml) + ProjectManager.read_config(os.path.join(TITANIC_DATA, 'rf_config.xml')) + uc = randomforest_usecase() uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) - def test_workflow_svc_on_titanic_from_json(self): - svc_args_json = Args() - svc_args_json.set_arg('plugin', '') - svc_args_json.set_arg('data', TITANIC_DATA) - svc_args_json.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.json')) - uc = svc_usecase(svc_args_json) - uc.run() - self.results.append(uc.get_results()) - self.assertTrue(uc.get_results().find("Solution") != -1) + # def test_workflow_svc_on_titanic_from_json(self): + # ProjectManager.read_config(os.path.join(TITANIC_DATA, 'svc_config.json')) + # uc = svc_usecase() + # uc.run() + # self.results.append(uc.get_results()) + # self.assertTrue(uc.get_results().find("Solution") != -1) def test_workflow_rf_on_titanic_from_json(self): - rf_args_json = Args() - rf_args_json.set_arg('plugin', '') - rf_args_json.set_arg('data', TITANIC_DATA) - rf_args_json.set_arg('config', os.path.join(TITANIC_DATA, 'rf_config.json')) - uc = randomforest_usecase(rf_args_json) + ProjectManager.read_config(os.path.join(TITANIC_DATA, 'rf_config.json')) + uc = randomforest_usecase() uc.run() self.results.append(uc.get_results()) self.assertTrue(uc.get_results().find("Solution") != -1) def tearDown(self): print("") for r in self.results: print(r) if __name__ == '__main__': unittest.main() diff --git a/hyppopy/workflows/dataloader/__init__.py b/hyppopy/workflows/dataloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hyppopy/workflows/dataloader/dataloaderbase.py b/hyppopy/workflows/dataloader/dataloaderbase.py new file mode 100644 index 0000000..83cd117 --- /dev/null +++ b/hyppopy/workflows/dataloader/dataloaderbase.py @@ -0,0 +1,36 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import abc + + +class DataLoaderBase(object): + + def __init__(self): + self.data = None + + def start(self, **kwargs): + self.read(**kwargs) + if self.data is None: + raise AttributeError("data is empty, did you forget to assign it while implementing read...?") + self.preprocess(**kwargs) + + @abc.abstractmethod + def read(self, **kwargs): + raise NotImplementedError("the read method has to be implemented in classes derived from DataLoader") + + @abc.abstractmethod + def preprocess(self, **kwargs): + pass diff --git a/hyppopy/workflows/dataloader/simpleloader.py b/hyppopy/workflows/dataloader/simpleloader.py new file mode 100644 index 0000000..6760cfd --- /dev/null +++ b/hyppopy/workflows/dataloader/simpleloader.py @@ -0,0 +1,41 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details.
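DataLoaderBase above implements a small template method: start() calls read(), checks that read() assigned self.data, then calls preprocess(). A minimal custom loader could look like this sketch; the loader name and the random data are made up for illustration.

import numpy as np
from hyppopy.workflows.dataloader.dataloaderbase import DataLoaderBase

class RandomDataLoader(DataLoaderBase):   # hypothetical example loader
    def read(self, **kwargs):
        # read() must assign self.data, otherwise start() raises AttributeError
        self.data = [np.random.rand(100, 4), np.random.randint(0, 2, 100)]

    def preprocess(self, **kwargs):
        pass   # optional hook, runs after read()

dl = RandomDataLoader()
dl.start()
X, y = dl.data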
+# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import numpy as np +import pandas as pd + +from hyppopy.workflows.dataloader.dataloaderbase import DataLoaderBase + + +class SimpleDataLoaderBase(DataLoaderBase): + + def read(self, **kwargs): + if kwargs['data_name'].endswith(".npy"): + if not kwargs['labels_name'].endswith(".npy"): + raise IOError("Expect both data_name and labels_name being of type .npy!") + self.data = [np.load(os.path.join(kwargs['path'], kwargs['data_name'])), np.load(os.path.join(kwargs['path'], kwargs['labels_name']))] + elif kwargs['data_name'].endswith(".csv"): + try: + dataset = pd.read_csv(os.path.join(kwargs['path'], kwargs['data_name'])) + y = dataset[kwargs['labels_name']].values + X = dataset.drop([kwargs['labels_name']], axis=1).values + self.data = [X, y] + except Exception as e: + print("Precondition violation, this use case expects data_name to be a " + "csv file and labels_name to be the name of a column in that csv table!") + else: + raise NotImplementedError("This combination of data_name and labels_name " + "does not yet exist, feel free to add it") diff --git a/hyppopy/workflows/dataloader/unetloader.py b/hyppopy/workflows/dataloader/unetloader.py new file mode 100644 index 0000000..9d71dcb --- /dev/null +++ b/hyppopy/workflows/dataloader/unetloader.py @@ -0,0 +1,79 @@ +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import numpy as np +from collections import defaultdict +from medpy.io import load +from hyppopy.workflows.dataloader.dataloaderbase import DataLoaderBase + + +class UnetDataLoaderBase(DataLoaderBase): + + def read(self, **kwargs): + pass + + def subfiles(self, folder, join=True, prefix=None, suffix=None, sort=True): + if join: + l = os.path.join + else: + l = lambda x, y: y + res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix))] + if sort: + res.sort() + return res + + def preprocess(self, **kwargs): + image_dir = os.path.join(kwargs['root_dir'], kwargs['image_dir']) + label_dir = os.path.join(kwargs['root_dir'], kwargs['labels_dir']) + output_dir = os.path.join(kwargs['root_dir'], kwargs['output_dir']) + classes = kwargs['classes'] + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print('Created ' + output_dir + '...') + + class_stats = defaultdict(int) + total = 0 + + nii_files = self.subfiles(image_dir, suffix=".nii.gz", join=False) + + for i in range(0, len(nii_files)): + if nii_files[i].startswith("._"): + nii_files[i] = nii_files[i][2:] + + for f in nii_files: + image, _ = load(os.path.join(image_dir, f)) + label, _ = load(os.path.join(label_dir, f.replace('_0000', ''))) + + print(f) + + for i in range(classes): + class_stats[i] += np.sum(label == i) + total += np.sum(label == i) + + image = (image - image.min()) / (image.max() - image.min()) + + image = reshape(image, append_value=0, new_shape=(64, 64, 64)) + label = reshape(label, append_value=0, new_shape=(64, 64, 64)) + + result = np.stack((image, label)) + + np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result) + print(f) + + print(total) + for i in range(classes): + print(class_stats[i], class_stats[i] /
total) diff --git a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py index c7ca0bc..5c14b6e 100644 --- a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py +++ b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py @@ -1,38 +1,38 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import cross_val_score -from hyppopy.workflows.workflowbase import Workflow -from hyppopy.workflows.datalaoder.simpleloader import SimpleDataLoader +from hyppopy.projectmanager import ProjectManager +from hyppopy.workflows.workflowbase import WorkflowBase +from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase -class randomforest_usecase(Workflow): - - def __init__(self, args): - Workflow.__init__(self, args) +class randomforest_usecase(WorkflowBase): def setup(self): - dl = SimpleDataLoader() - dl.read(path=self.args.data, data_name=self.solver.settings.data_name, labels_name=self.solver.settings.labels_name) - self.solver.set_data(dl.get()) + dl = SimpleDataLoaderBase() + dl.start(path=ProjectManager.data_path, + data_name=ProjectManager.data_name, + labels_name=ProjectManager.labels_name) + self.solver.set_data(dl.data) def blackbox_function(self, data, params): if "n_estimators" in params.keys(): params["n_estimators"] = int(round(params["n_estimators"])) clf = RandomForestClassifier(**params) return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() diff --git a/hyppopy/workflows/svc_usecase/svc_usecase.py b/hyppopy/workflows/svc_usecase/svc_usecase.py index 18f93f1..4108969 100644 --- a/hyppopy/workflows/svc_usecase/svc_usecase.py +++ b/hyppopy/workflows/svc_usecase/svc_usecase.py @@ -1,41 +1,38 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
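Putting the pieces together, a workflow run under the new ProjectManager scheme looks roughly like the sketch below; the config path is illustrative, and the file must provide settings/custom/data_path, data_name and labels_name as in the test configs above.

from hyppopy.projectmanager import ProjectManager
from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase

ProjectManager.read_config("path/to/rf_config.json")   # illustrative path
uc = randomforest_usecase()
uc.run()
print(uc.get_results())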
diff --git a/hyppopy/workflows/svc_usecase/svc_usecase.py b/hyppopy/workflows/svc_usecase/svc_usecase.py
index 18f93f1..4108969 100644
--- a/hyppopy/workflows/svc_usecase/svc_usecase.py
+++ b/hyppopy/workflows/svc_usecase/svc_usecase.py
@@ -1,41 +1,38 @@
-# -*- coding: utf-8 -*-
-#
 # DKFZ
 #
 #
 # Copyright (c) German Cancer Research Center,
 # Division of Medical and Biological Informatics.
 # All rights reserved.
 #
 # This software is distributed WITHOUT ANY WARRANTY; without
 # even the implied warranty of MERCHANTABILITY or FITNESS FOR
 # A PARTICULAR PURPOSE.
 #
 # See LICENSE.txt or http://www.mitk.org for details.
 #
 # Author: Sven Wanner (s.wanner@dkfz.de)

 import os
 import numpy as np
 import pandas as pd
 from sklearn.svm import SVC
 from sklearn.model_selection import cross_val_score

-from hyppopy.workflows.workflowbase import Workflow
-from hyppopy.workflows.datalaoder.simpleloader import SimpleDataLoader
-
+from hyppopy.projectmanager import ProjectManager
+from hyppopy.workflows.workflowbase import WorkflowBase
+from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase


-class svc_usecase(Workflow):
-    def __init__(self, args):
-        Workflow.__init__(self, args)
+class svc_usecase(WorkflowBase):

     def setup(self):
-        dl = SimpleDataLoader()
-        dl.read(path=self.args.data, data_name=self.solver.settings.data_name,
-                labels_name=self.solver.settings.labels_name)
-        self.solver.set_data(dl.get())
+        dl = SimpleDataLoaderBase()
+        dl.start(path=ProjectManager.data_path,
+                 data_name=ProjectManager.data_name,
+                 labels_name=ProjectManager.labels_name)
+        self.solver.set_data(dl.data)

     def blackbox_function(self, data, params):
         clf = SVC(**params)
         return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean()
diff --git a/hyppopy/workflows/unet_usecase/unet_uscase_utils.py b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py
index ae1f4ce..9b9147c 100644
--- a/hyppopy/workflows/unet_usecase/unet_uscase_utils.py
+++ b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py
@@ -1,419 +1,417 @@
-# -*- coding: utf-8 -*-
-#
 # DKFZ
 #
 #
 # Copyright (c) German Cancer Research Center,
 # Division of Medical and Biological Informatics.
 # All rights reserved.
 #
 # This software is distributed WITHOUT ANY WARRANTY; without
 # even the implied warranty of MERCHANTABILITY or FITNESS FOR
 # A PARTICULAR PURPOSE.
 #
 # See LICENSE.txt or http://www.mitk.org for details.
 #
 # Author: Sven Wanner (s.wanner@dkfz.de)

 import os
 import torch
 import pickle
 import fnmatch
+import random  # needed by NumpyDataLoader.reshuffle()/generate_train_batch() below
 import numpy as np

 from torch import nn
 from medpy.io import load
 from collections import defaultdict

 from abc import ABCMeta, abstractmethod


 def sum_tensor(input, axes, keepdim=False):
     axes = np.unique(axes).astype(int)
     if keepdim:
         for ax in axes:
             input = input.sum(int(ax), keepdim=True)
     else:
         for ax in sorted(axes, reverse=True):
             input = input.sum(int(ax))
     return input


 def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None):
     if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
         rebalance_weights = rebalance_weights[1:]  # this is the case when use_bg=False
     axes = tuple([0] + list(range(2, len(net_output.size()))))
     tp = sum_tensor(net_output * gt, axes, keepdim=False)
     fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
     fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
     weights = torch.ones(tp.shape)
     weights[0] = background_weight
     if net_output.device.type == "cuda":
         weights = weights.cuda(net_output.device.index)
     if rebalance_weights is not None:
         rebalance_weights = torch.from_numpy(rebalance_weights).float()
         if net_output.device.type == "cuda":
             rebalance_weights = rebalance_weights.cuda(net_output.device.index)
         tp = tp * rebalance_weights
         fn = fn * rebalance_weights
     result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean()
     return result


 def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
     axes = tuple(range(2, len(net_output.size())))
     intersect = sum_tensor(net_output * gt, axes, keepdim=False)
     denom = sum_tensor(net_output + gt, axes, keepdim=False)
     # no per-class weighting here, in contrast to soft_dice_per_batch_2
     result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
     return result


 class SoftDiceLoss(nn.Module):
     def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True,
                  background_weight=1, rebalance_weights=None):
         """
         Soft (differentiable) dice loss, returned negated so that lower is better.
         :param smooth: smoothing constant added to the denominator
         :param apply_nonlin: optional nonlinearity applied to the network output first (e.g. softmax)
         :param batch_dice: compute the dice over the whole batch instead of per sample
         :param do_bg: include the background channel in the loss
         :param smooth_in_nom: also add the smoothing constant to the numerator
         :param background_weight: weight of the background channel (batch_dice only)
         :param rebalance_weights: per-class rebalancing weights (batch_dice only)
         """
         super(SoftDiceLoss, self).__init__()
         if not do_bg:
             assert background_weight == 1, "background_weight must be 1 when do_bg=False"
         self.rebalance_weights = rebalance_weights
         self.background_weight = background_weight
         self.smooth_in_nom = smooth_in_nom
         self.do_bg = do_bg
         self.batch_dice = batch_dice
         self.apply_nonlin = apply_nonlin
         self.smooth = smooth
         self.y_onehot = None
         if not smooth_in_nom:
             self.nom_smooth = 0
         else:
             self.nom_smooth = smooth

     def forward(self, x, y):
         with torch.no_grad():
             y = y.long()
         shp_x = x.shape
         shp_y = y.shape
         if self.apply_nonlin is not None:
             x = self.apply_nonlin(x)
         if len(shp_x) != len(shp_y):
             y = y.view((shp_y[0], 1, *shp_y[1:]))
         # now x and y should have shape (B, C, X, Y(, Z)) and (B, 1, X, Y(, Z)), respectively
         y_onehot = torch.zeros(shp_x)
         if x.device.type == "cuda":
             y_onehot = y_onehot.cuda(x.device.index)
         y_onehot.scatter_(1, y, 1)
         if not self.do_bg:
             x = x[:, 1:]
             y_onehot = y_onehot[:, 1:]
         if not self.batch_dice:
             if self.background_weight != 1 or (self.rebalance_weights is not None):
                 raise NotImplementedError("background_weight != 1 and rebalance_weights "
                                           "are only supported with batch_dice=True")
             loss = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom)
         else:
             loss = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom,
                                          background_weight=self.background_weight,
                                          rebalance_weights=self.rebalance_weights)
         return loss
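A quick smoke test for SoftDiceLoss; shapes and values are purely illustrative, and the snippet assumes the definitions above are importable:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 3, 8, 8)                    # (B, C, X, Y) raw network output
    target = torch.randint(0, 3, (2, 8, 8))             # (B, X, Y) integer class labels
    criterion = SoftDiceLoss(apply_nonlin=lambda x: F.softmax(x, dim=1), batch_dice=True)
    loss = criterion(logits, target)                    # negative soft dice, lower is better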
 def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
     fls = []
     files_len = []
     slices_ax = []

     for root, dirs, files in os.walk(base_dir):
         i = 0
         for filename in sorted(fnmatch.filter(files, pattern)):
             if keys is not None and filename[:-4] in keys:
                 npy_file = os.path.join(root, filename)
                 numpy_array = np.load(npy_file, mmap_mode="r")
                 fls.append(npy_file)
                 files_len.append(numpy_array.shape[1])
                 slices_ax.extend([(i, j) for j in range(slice_offset, files_len[-1] - slice_offset)])
                 i += 1
     return fls, files_len, slices_ax


 class SlimDataLoaderBase(object):
     def __init__(self, data, batch_size, number_of_threads_in_multithreaded=None):
         """
         Slim version of DataLoaderBase (which is now deprecated). Only provides very simple functionality.
         You must derive from this class to implement your own DataLoader. You must override
         self.generate_train_batch(). If you use our MultiThreadedAugmenter you will need to also set and use
         number_of_threads_in_multithreaded. See multithreaded_dataloading in examples!
         :param data: will be stored in self._data. You can use it to generate your batches in
             self.generate_train_batch()
         :param batch_size: will be stored in self.batch_size for use in self.generate_train_batch()
         :param number_of_threads_in_multithreaded: will be stored in
             self.number_of_threads_in_multithreaded. None per default. If you wish to iterate over all your
             training data only once per epoch, you must coordinate your DataLoaders and you will need this
             information.
         """
         __metaclass__ = ABCMeta
         self.number_of_threads_in_multithreaded = number_of_threads_in_multithreaded
         self._data = data
         self.batch_size = batch_size
         self.thread_id = 0

     def set_thread_id(self, thread_id):
         self.thread_id = thread_id

     def __iter__(self):
         return self

     def __next__(self):
         return self.generate_train_batch()

     @abstractmethod
     def generate_train_batch(self):
         """Override this: generate your batch from self._data. Make sure you generate the
         correct batch size (self.batch_size)."""
         pass


 class NumpyDataLoader(SlimDataLoaderBase):
     def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                  seed=None, file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None):

         self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern,
                                                               slice_offset=0, keys=keys)
         super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches)

         self.batch_size = batch_size

         self.use_next = False
         if mode == "train":
             self.use_next = False

         self.slice_idxs = list(range(0, len(self.slices)))

         self.data_len = len(self.slices)

         self.num_batches = min((self.data_len // self.batch_size) + 10, num_batches)

         if isinstance(label_slice, int):
             label_slice = (label_slice,)
         self.input_slice = input_slice
         self.label_slice = label_slice

         self.np_data = np.asarray(self.slices)

     def reshuffle(self):
         print("Reshuffle...")
         random.shuffle(self.slice_idxs)
         print("Initializing... this might take a while...")

     def generate_train_batch(self):
         open_arr = random.sample(self._data, self.batch_size)
         return self.get_data_from_array(open_arr)

     def __len__(self):
         n_items = min(self.data_len // self.batch_size, self.num_batches)
         return n_items

     def __getitem__(self, item):
         slice_idxs = self.slice_idxs
         data_len = len(self.slices)
         np_data = self.np_data

         if item > len(self):
             raise StopIteration()
         if (item * self.batch_size) == data_len:
             raise StopIteration()

         start_idx = (item * self.batch_size) % data_len
         stop_idx = ((item + 1) * self.batch_size) % data_len

         if ((item + 1) * self.batch_size) == data_len:
             stop_idx = data_len

         if stop_idx > start_idx:
             idxs = slice_idxs[start_idx:stop_idx]
         else:
             raise StopIteration()

         open_arr = np_data[idxs]

         return self.get_data_from_array(open_arr)

     def get_data_from_array(self, open_array):
         data = []
         fnames = []
         slice_idxs = []
         labels = []

         for slice in open_array:
             fn_name = self.files[slice[0]]

             numpy_array = np.load(fn_name, mmap_mode="r")

             numpy_slice = numpy_array[:, slice[1], ]
             data.append(numpy_slice[None, self.input_slice[0]])   # 'None' keeps the dimension

             if self.label_slice is not None:
                 labels.append(numpy_slice[None, self.label_slice[0]])   # 'None' keeps the dimension

             fnames.append(self.files[slice[0]])
             slice_idxs.append(slice[1])

         ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs}
         if self.label_slice is not None:
             ret_dict['seg'] = np.asarray(labels)

         return ret_dict
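Batch access, for reference. The base_dir and keys are hypothetical; note that with keys=None the load_dataset above collects nothing, so callers have to pass the patient keys (e.g. from splits.pkl):

    loader = NumpyDataLoader(base_dir="/data/preprocessed", batch_size=4,
                             keys=["hippocampus_001", "hippocampus_002"])
    batch = loader[0]
    print(batch['data'].shape)    # (4, 1, H, W) image slices
    print(batch['seg'].shape)     # (4, 1, H, W) label slices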
 class NumpyDataSet(object):
     """
     Wraps a NumpyDataLoader in a multithreaded augmentation pipeline.
     """
     def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None,
                  num_processes=8, num_cached_per_queue=8 * 4, target_size=128, file_pattern='*.npy',
                  label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None):

         data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size,
                                       num_batches=num_batches, seed=seed, file_pattern=file_pattern,
                                       input_slice=input_slice, label_slice=label_slice, keys=keys)

         self.data_loader = data_loader
         self.batch_size = batch_size
         self.do_reshuffle = do_reshuffle
         self.number_of_slices = 1

         # get_transforms and MultiThreadedDataLoader are not defined in this module;
         # they must be provided by the surrounding UNet example code
         self.transforms = get_transforms(mode=mode, target_size=target_size)
         self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,
                                                  num_cached_per_queue=num_cached_per_queue, seeds=seed,
                                                  shuffle=do_reshuffle)
         self.augmenter.restart()

     def __len__(self):
         return len(self.data_loader)

     def __iter__(self):
         if self.do_reshuffle:
             self.data_loader.reshuffle()
         self.augmenter.renew()
         return self.augmenter

     def __next__(self):
         return next(self.augmenter)


 def reshape(orig_img, append_value=-1024, new_shape=(512, 512, 512)):
     reshaped_image = np.zeros(new_shape)
     reshaped_image[...] = append_value
     x_offset = 0
     y_offset = 0  # (new_shape[1] - orig_img.shape[1]) // 2
     z_offset = 0  # (new_shape[2] - orig_img.shape[2]) // 2

     # the volume is embedded at the origin; all remaining voxels keep append_value
     reshaped_image[x_offset:orig_img.shape[0] + x_offset,
                    y_offset:orig_img.shape[1] + y_offset,
                    z_offset:orig_img.shape[2] + z_offset] = orig_img

     return reshaped_image


 def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
     if join:
         l = os.path.join
     else:
         l = lambda x, y: y
     res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
            and (prefix is None or i.startswith(prefix))
            and (suffix is None or i.endswith(suffix))]
     if sort:
         res.sort()
     return res


 def preprocess_data(root_dir):
     print("preprocess data...")
     image_dir = os.path.join(root_dir, 'imagesTr')
     print(f"image_dir: {image_dir}")
     label_dir = os.path.join(root_dir, 'labelsTr')
     print(f"label_dir: {label_dir}")
     output_dir = os.path.join(root_dir, 'preprocessed')
     print(f"output_dir: {output_dir} ... ", end="")
     classes = 3

     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
         print("created!")
     else:
         print("found!\npreprocessed data already available, aborted preprocessing!")
         return False

     print("start preprocessing ... ", end="")
     class_stats = defaultdict(int)
     total = 0

     nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)

     for i in range(0, len(nii_files)):
         if nii_files[i].startswith("._"):
             nii_files[i] = nii_files[i][2:]

     for f in nii_files:
         image, _ = load(os.path.join(image_dir, f))
         label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))

         for c in range(classes):
             class_stats[c] += np.sum(label == c)
             total += np.sum(label == c)

         image = (image - image.min()) / (image.max() - image.min())

         image = reshape(image, append_value=0, new_shape=(64, 64, 64))
         label = reshape(label, append_value=0, new_shape=(64, 64, 64))

         result = np.stack((image, label))

         np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result)

     print("finished!")
     return True
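Note that reshape only pads, it never crops, so the input must fit inside new_shape. A tiny sanity check (values illustrative):

    import numpy as np

    vol = np.ones((40, 50, 30))
    cube = reshape(vol, append_value=0, new_shape=(64, 64, 64))
    assert cube.shape == (64, 64, 64)
    assert cube.sum() == vol.sum()   # padding is zero, content preserved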
", end="") npy_files = subfiles(image_dir, suffix=".npy", join=False) trainset_size = len(npy_files) * 50 // 100 valset_size = len(npy_files) * 25 // 100 testset_size = len(npy_files) * 25 // 100 splits = [] for split in range(0, 5): image_list = npy_files.copy() trainset = [] valset = [] testset = [] for i in range(0, trainset_size): patient = np.random.choice(image_list) image_list.remove(patient) trainset.append(patient[:-4]) for i in range(0, valset_size): patient = np.random.choice(image_list) image_list.remove(patient) valset.append(patient[:-4]) for i in range(0, testset_size): patient = np.random.choice(image_list) image_list.remove(patient) testset.append(patient[:-4]) split_dict = dict() split_dict['train'] = trainset split_dict['val'] = valset split_dict['test'] = testset splits.append(split_dict) with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f: pickle.dump(splits, f) print("finished!") diff --git a/hyppopy/workflows/unet_usecase/unet_usecase.py b/hyppopy/workflows/unet_usecase/unet_usecase.py index 550400d..4666f20 100644 --- a/hyppopy/workflows/unet_usecase/unet_usecase.py +++ b/hyppopy/workflows/unet_usecase/unet_usecase.py @@ -1,126 +1,33 @@ -# -*- coding: utf-8 -*- -# # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os -import pickle - -import torch -import torch.optim as optim -from torch.optim.lr_scheduler import ReduceLROnPlateau -import torch.nn.functional as F -from networks.RecursiveUNet import UNet - -import hyppopy.solverfactory as sfac -from .unet_uscase_utils import * - - -def unet_usecase(args): - print("Execute UNet UseCase...") - data_dir = args.data - preprocessed_dir = os.path.join(args.data, 'preprocessed') - solver_plugin = args.plugin - config_file = args.config - print(f"input data directory: {data_dir}") - print(f"use plugin: {solver_plugin}") - print(f"config file: {config_file}") - - factory = sfac.SolverFactory.instance() - solver = factory.get_solver(solver_plugin) - solver.read_parameter(config_file) - - if preprocess_data(data_dir): - create_splits(output_dir=data_dir, image_dir=preprocessed_dir) - - with open(os.path.join(data_dir, "splits.pkl"), 'rb') as f: - splits = pickle.load(f) - - tr_keys = splits[solver.settings.fold]['train'] - val_keys = splits[solver.settings.fold]['val'] - test_keys = splits[solver.settings.fold]['test'] - - def loss_function(patch_size, batch_size): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - train_data_loader = NumpyDataSet(data_dir, - target_size=patch_size, - batch_size=batch_size, - keys=tr_keys) - val_data_loader = NumpyDataSet(data_dir, - target_size=patch_size, - batch_size=batch_size, - mode="val", - do_reshuffle=False) - model = UNet(num_classes=solver.settings.num_classes, in_channels=solver.settings.in_channels) - model.to(device) - - dice_loss = SoftDiceLoss(batch_dice=True) - ce_loss = torch.nn.CrossEntropyLoss() - node_optimizer = optim.Adam(model.parameters(), lr=solver.settings.learning_rate) - scheduler = ReduceLROnPlateau(node_optimizer, 'min') - - model.train() - - data = None - batch_counter = 0 - for data_batch in train_data_loader: - - node_optimizer.zero_grad() - - data = data_batch['data'][0].float().to(device) - target = 
diff --git a/hyppopy/workflows/unet_usecase/unet_usecase.py b/hyppopy/workflows/unet_usecase/unet_usecase.py
index 550400d..4666f20 100644
--- a/hyppopy/workflows/unet_usecase/unet_usecase.py
+++ b/hyppopy/workflows/unet_usecase/unet_usecase.py
@@ -1,126 +1,33 @@
-# -*- coding: utf-8 -*-
-#
 # DKFZ
 #
 #
 # Copyright (c) German Cancer Research Center,
 # Division of Medical and Biological Informatics.
 # All rights reserved.
 #
 # This software is distributed WITHOUT ANY WARRANTY; without
 # even the implied warranty of MERCHANTABILITY or FITNESS FOR
 # A PARTICULAR PURPOSE.
 #
 # See LICENSE.txt or http://www.mitk.org for details.
 #
 # Author: Sven Wanner (s.wanner@dkfz.de)

 import os
-import pickle
-
-import torch
-import torch.optim as optim
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-import torch.nn.functional as F
-from networks.RecursiveUNet import UNet
-
-import hyppopy.solverfactory as sfac
-from .unet_uscase_utils import *
-
-
-def unet_usecase(args):
-    print("Execute UNet UseCase...")
-    data_dir = args.data
-    preprocessed_dir = os.path.join(args.data, 'preprocessed')
-    solver_plugin = args.plugin
-    config_file = args.config
-    print(f"input data directory: {data_dir}")
-    print(f"use plugin: {solver_plugin}")
-    print(f"config file: {config_file}")
-
-    factory = sfac.SolverFactory.instance()
-    solver = factory.get_solver(solver_plugin)
-    solver.read_parameter(config_file)
-
-    if preprocess_data(data_dir):
-        create_splits(output_dir=data_dir, image_dir=preprocessed_dir)
-
-    with open(os.path.join(data_dir, "splits.pkl"), 'rb') as f:
-        splits = pickle.load(f)
-
-    tr_keys = splits[solver.settings.fold]['train']
-    val_keys = splits[solver.settings.fold]['val']
-    test_keys = splits[solver.settings.fold]['test']
-
-    def loss_function(patch_size, batch_size):
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        train_data_loader = NumpyDataSet(data_dir,
-                                         target_size=patch_size,
-                                         batch_size=batch_size,
-                                         keys=tr_keys)
-        val_data_loader = NumpyDataSet(data_dir,
-                                       target_size=patch_size,
-                                       batch_size=batch_size,
-                                       mode="val",
-                                       do_reshuffle=False)
-        model = UNet(num_classes=solver.settings.num_classes, in_channels=solver.settings.in_channels)
-        model.to(device)
-
-        dice_loss = SoftDiceLoss(batch_dice=True)
-        ce_loss = torch.nn.CrossEntropyLoss()
-        node_optimizer = optim.Adam(model.parameters(), lr=solver.settings.learning_rate)
-        scheduler = ReduceLROnPlateau(node_optimizer, 'min')
-
-        model.train()
-
-        data = None
-        batch_counter = 0
-        for data_batch in train_data_loader:
-
-            node_optimizer.zero_grad()
-
-            data = data_batch['data'][0].float().to(device)
-            target = data_batch['seg'][0].long().to(device)
-
-            pred = model(data)
-            pred_softmax = F.softmax(pred, dim=1)
-
-            loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze())
-            loss.backward()
-            node_optimizer.step()
-            batch_counter += 1
-
-        assert data is not None, 'data is None. Please check if your dataloader works properly'
-
-        model.eval()
-
-        data = None
-        loss_list = []
-
-        with torch.no_grad():
-            for data_batch in val_data_loader:
-                data = data_batch['data'][0].float().to(device)
-                target = data_batch['seg'][0].long().to(device)
-
-                pred = model(data)
-                pred_softmax = F.softmax(pred)
-
-                loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze())
-                loss_list.append(loss.item())
-
-        assert data is not None, 'data is None. Please check if your dataloader works properly'
-        scheduler.step(np.mean(loss_list))
-
-    data = []
-
-    # solver.set_data(data)
-    # solver.read_parameter(config_file)
-    # solver.set_loss_function(loss_function)
-    # solver.run()
-    # solver.get_results()
-
+from hyppopy.projectmanager import ProjectManager
+from hyppopy.workflows.workflowbase import WorkflowBase
+from hyppopy.workflows.dataloader.unetloader import UnetDataLoaderBase


+class unet_usecase(WorkflowBase):

+    def setup(self):
+        pass

+    def blackbox_function(self, data, params):
+        pass
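The new unet_usecase is deliberately a stub. Purely as a speculative sketch — the preprocess kwargs mirror UnetDataLoaderBase above and the directory names are assumptions, none of this is part of the diff — setup might eventually look like:

    def setup(self):
        dl = UnetDataLoaderBase()
        dl.preprocess(root_dir=ProjectManager.data_path,   # assumed decathlon-style layout
                      image_dir="imagesTr",
                      labels_dir="labelsTr",
                      output_dir="preprocessed",
                      classes=3)
        self.solver.set_data(dl.data)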
diff --git a/hyppopy/workflows/workflowbase.py b/hyppopy/workflows/workflowbase.py
index fea491b..61c879b 100644
--- a/hyppopy/workflows/workflowbase.py
+++ b/hyppopy/workflows/workflowbase.py
@@ -1,76 +1,61 @@
 # -*- coding: utf-8 -*-
 #
 # DKFZ
 #
 #
 # Copyright (c) German Cancer Research Center,
 # Division of Medical and Biological Informatics.
 # All rights reserved.
 #
 # This software is distributed WITHOUT ANY WARRANTY; without
 # even the implied warranty of MERCHANTABILITY or FITNESS FOR
 # A PARTICULAR PURPOSE.
 #
 # See LICENSE.txt or http://www.mitk.org for details.
 #
 # Author: Sven Wanner (s.wanner@dkfz.de)

-from hyppopy.solverfactory import SolverFactory
 from hyppopy.deepdict import DeepDict
+from hyppopy.solverfactory import SolverFactory
+from hyppopy.projectmanager import ProjectManager
 from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH

 import os
 import abc
 import logging
 from hyppopy.globals import DEBUGLEVEL
 LOG = logging.getLogger(os.path.basename(__file__))
 LOG.setLevel(DEBUGLEVEL)


-class Workflow(object):
-    _solver = None
-    _args = None
+class WorkflowBase(object):

-    def __init__(self, args):
-        self._args = args
-        if args.plugin is None or args.plugin == '':
-            dd = DeepDict(args.config)
-            ppath = "use_plugin"
-            if not dd.has_section(ppath):
-                msg = f"invalid config file, missing section {ppath}"
-                LOG.error(msg)
-                raise LookupError(msg)
-            plugin = dd[SETTINGSSOLVERPATH+'/'+ppath]
-        else:
-            plugin = args.plugin
-        self._solver = SolverFactory.get_solver(plugin)
-        self.solver.read_parameter(args.config)
+    def __init__(self):
+        self._solver = SolverFactory.get_solver(ProjectManager.use_plugin)
+        self.solver.set_hyperparameters(ProjectManager.config['hyperparameter'])

     def run(self):
         self.setup()
         self.solver.set_loss_function(self.blackbox_function)
         self.solver.run()
         self.test()

     def get_results(self):
         return self.solver.get_results()

     @abc.abstractmethod
     def setup(self):
         raise NotImplementedError('the user has to implement this function')

     @abc.abstractmethod
     def blackbox_function(self):
         raise NotImplementedError('the user has to implement this function')

     @abc.abstractmethod
     def test(self):
         pass

     @property
     def solver(self):
         return self._solver

-    @property
-    def args(self):
-        return self._args
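With this refactoring, a new use case reduces to two overrides. A minimal sketch — the toy objective and the hyperparameter name 'x' are illustrative and would have to be declared in the project configuration:

    from hyppopy.workflows.workflowbase import WorkflowBase

    class my_usecase(WorkflowBase):

        def setup(self):
            self.solver.set_data([])            # nothing to load for a toy objective

        def blackbox_function(self, data, params):
            x = params["x"]
            return (x - 2.0) ** 2               # minimum at x == 2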