diff --git a/bin/hyppopy_exe.py b/bin/hyppopy_exe.py index 8bf5103..ef0610e 100644 --- a/bin/hyppopy_exe.py +++ b/bin/hyppopy_exe.py @@ -1,69 +1,69 @@ #!/usr/bin/env python # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.unet_usecase.unet_usecase import unet_usecase from hyppopy.workflows.svc_usecase.svc_usecase import svc_usecase from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase import os import sys import argparse def print_warning(msg): print("\n!!!!! WARNING !!!!!") print(msg) sys.exit() def args_check(args): if not args.workflow: print_warning("No workflow specified, check --help") if not args.config: print_warning("Missing config parameter, check --help") if not os.path.isfile(args.config): print_warning(f"Couldn't find configfile ({args.config}), please check your input --config") if __name__ == "__main__": parser = argparse.ArgumentParser(description='UNet Hyppopy UseCase Example Optimization.') parser.add_argument('-w', '--workflow', type=str, help='workflow to be executed') parser.add_argument('-c', '--config', type=str, help='config filename, .xml or .json formats are supported.' 
'pass a full path filename or the filename only if the' 'configfile is in the data folder') args = parser.parse_args() args_check(args) ProjectManager.read_config(args.config) if args.workflow == "svc_usecase": uc = svc_usecase() elif args.workflow == "randomforest_usecase": uc = randomforest_usecase() elif args.workflow == "unet_usecase": uc = unet_usecase() else: - print(f"No workflow called {args.workflow} found!") + print("No workflow called {} found!".format(args.workflow)) sys.exit() uc.run() print(uc.get_results()) diff --git a/hyppopy/deepdict.py b/hyppopy/deepdict.py index f8e84d5..8c19412 100644 --- a/hyppopy/deepdict.py +++ b/hyppopy/deepdict.py @@ -1,437 +1,437 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import re import json import types import pprint import xmltodict from dicttoxml import dicttoxml from collections import OrderedDict import logging LOG = logging.getLogger('hyppopy') from hyppopy.globals import DEEPDICT_XML_ROOT def convert_ordered2std_dict(obj): """ Helper function converting an OrderedDict into a standard lib dict. :param obj: [OrderedDict] """ for key, value in obj.items(): if isinstance(value, OrderedDict): obj[key] = dict(obj[key]) convert_ordered2std_dict(obj[key]) def check_dir_existance(dirname): """ Helper function to check if a directory exists, creating it if not. :param dirname: [str] full path of the directory to check """ if not os.path.exists(dirname): os.mkdir(dirname) class DeepDict(object): """ The DeepDict class represents a nested dictionary with additional functionality compared to a standard lib dict. 
The data can be accessed and changed vie a pathlike access and dumped or read to .json/.xml files. Initializing instances using defaults creates an empty DeepDict. Using in_data enables to initialize the object instance with data, where in_data can be a dict, or a filepath to a json or xml file. Using path sep the appearance of path passing can be changed, a default data access via path would look like my_dd['target/section/path'] with path_sep='.' like so my_dd['target.section.path'] :param in_data: [dict] or [str], input dict or filename :param path_sep: [str] path separator character """ _data = None _sep = "/" def __init__(self, in_data=None, path_sep="/"): self.clear() self._sep = path_sep - LOG.debug(f"path separator is: {self._sep}") + LOG.debug("path separator is: {}".format(self._sep)) if in_data is not None: if isinstance(in_data, str): self.from_file(in_data) elif isinstance(in_data, dict): self.data = in_data def __str__(self): """ Enables print output for class instances, printing the instance data dict using pretty print :return: [str] """ return pprint.pformat(self.data) def __eq__(self, other): """ Overloads the == operator comparing the instance data dictionaries for equality :param other: [DeepDict] rhs :return: [bool] """ return self.data == other.data def __getitem__(self, path): """ Overloads the return of the [] operator for data access. This enables access the DeepDict instance like so: my_dd['target/section/path'] or my_dd[['target','section','path']] :param path: [str] or [list(str)], the path to the target data structure level/content :return: [object] """ return DeepDict.get_from_path(self.data, path, self.sep) def __setitem__(self, path, value=None): """ Overloads the setter for the [] operator for data assignment. 
:param path: [str] or [list(str)], the path to the target data structure level/content :param value: [object] rhs assignment object """ if isinstance(path, str): path = path.split(self.sep) if not isinstance(path, list) or isinstance(path, tuple): raise IOError("Input Error, expect list[str] type for path") if len(path) < 1: raise IOError("Input Error, missing section strings") if not path[0] in self._data.keys(): if value is not None and len(path) == 1: self._data[path[0]] = value else: self._data[path[0]] = {} tmp = self._data[path[0]] path.pop(0) while True: if len(path) == 0: break if path[0] not in tmp.keys(): if value is not None and len(path) == 1: tmp[path[0]] = value else: tmp[path[0]] = {} tmp = tmp[path[0]] else: tmp = tmp[path[0]] path.pop(0) def __len__(self): return len(self._data) def items(self): return self.data.items() def clear(self): """ clears the instance data """ LOG.debug("clear()") self._data = {} def from_file(self, fname): """ Loads data from file. Currently implemented .json and .xml file reader. 
:param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.read_json(fname) elif fname.endswith(".xml"): self.read_xml(fname) else: LOG.error("Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def read_json(self, fname): """ Read json file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): - raise IOError(f"File {fname} not found!") - LOG.debug(f"read_json({fname})") + raise IOError("File {} not found!".format(fname)) + LOG.debug("read_json({})".format(fname)) try: with open(fname, "r") as read_file: self._data = json.load(read_file) DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: - LOG.error(f"Error while reading json file {fname} or while converting types") - raise IOError("Error while reading json file {fname} or while converting types") + LOG.error("Error while reading json file {} or while converting types".format(fname)) + raise IOError("Error while reading json file {} or while converting types".format(fname)) def read_xml(self, fname): """ Read xml file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): - raise IOError(f"File {fname} not found!") - LOG.debug(f"read_xml({fname})") + raise IOError("File {} not found!".format(fname)) + LOG.debug("read_xml({})".format(fname)) try: with open(fname, "r") as read_file: xml = "".join(read_file.readlines()) self._data = xmltodict.parse(xml, attr_prefix='') DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: - msg = f"Error while reading xml file {fname} or while converting types" + msg = "Error while reading xml file {} or while converting types".format(fname) LOG.error(msg) raise 
IOError(msg) # if written with DeepDict, the xml contains a root node called # deepdict which should beremoved for consistency reasons if DEEPDICT_XML_ROOT in self._data.keys(): self._data = self._data[DEEPDICT_XML_ROOT] self._data = dict(self.data) # convert the orderes dict structure to a default dict for consistency reasons convert_ordered2std_dict(self.data) def to_file(self, fname): """ Write to file, type is determined by checking the filename ending. Currently implemented is writing to json and to xml. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.write_json(fname) elif fname.endswith(".xml"): self.write_xml(fname) else: - LOG.error(f"Unknown filetype, expect [.json, .xml]") + LOG.error("Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def write_json(self, fname): """ Dump data to json file. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) try: - LOG.debug(f"write_json({fname})") + LOG.debug("write_json({})".format(fname)) with open(fname, "w") as write_file: json.dump(self.data, write_file) except Exception as e: - LOG.error(f"Failed dumping to json file: {fname}") + LOG.error("Failed dumping to json file: {}".format(fname)) raise e def write_xml(self, fname): """ Dump data to json file. 
:param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) xml = dicttoxml(self.data, custom_root=DEEPDICT_XML_ROOT, attr_type=False) - LOG.debug(f"write_xml({fname})") + LOG.debug("write_xml({})".format(fname)) try: with open(fname, "w") as write_file: write_file.write(xml.decode("utf-8")) except Exception as e: - LOG.error(f"Failed dumping to xml file: {fname}") + LOG.error("Failed dumping to xml file: {}".format(fname)) raise e def has_section(self, section): return DeepDict.has_key(self.data, section) @staticmethod def get_from_path(data, path, sep="/"): """ Implements a nested dict access via a path like string like so path='target/section/path' which is equivalent to my_dict['target']['section']['path']. :param data: [dict] input dictionary :param path: [str] pathlike string :param sep: [str] path separator, default='/' :return: [object] """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for data") raise IOError("Input Error, expect dict type for data") if isinstance(path, str): path = path.split(sep) if not isinstance(path, list) or isinstance(path, tuple): - LOG.error(f"Input Error, expect list[str] type for path: {path}") + LOG.error("Input Error, expect list[str] type for path: {}".format(path)) raise IOError("Input Error, expect list[str] type for path") if not DeepDict.has_key(data, path[-1]): - LOG.error(f"Input Error, section {path[-1]} does not exist in dictionary") - raise IOError(f"Input Error, section {path[-1]} does not exist in dictionary") + LOG.error("Input Error, section {} does not exist in dictionary".format(path[-1])) + raise IOError("Input Error, section {} does not exist in dictionary".format(path[-1])) try: for k in path: data = data[k] except Exception as e: - LOG.error(f"Failed retrieving data from path {path} due to {e}") - raise LookupError(f"Failed retrieving data from path {path} due to {e}") + 
LOG.error("Failed retrieving data from path {} due to {}".format(path, e)) + raise LookupError("Failed retrieving data from path {} due to {}".format(path, e)) return data @staticmethod def has_key(data, section, already_found=False): """ Checks if input dictionary has a key called section. The already_found parameter is for internal recursion checks. :param data: [dict] input dictionary :param section: [str] key string to search for :param already_found: recursion criteria check :return: [bool] section found """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(section, str): - LOG.error(f"Input Error, expect dict type for obj {section}") - raise IOError(f"Input Error, expect dict type for obj {section}") + LOG.error("Input Error, expect dict type for obj {}".format(section)) + raise IOError("Input Error, expect dict type for obj {}".format(section)) if already_found: return True found = False for key, value in data.items(): if key == section: found = True if isinstance(value, dict): found = DeepDict.has_key(data[key], section, found) return found @staticmethod def value_traverse(data, callback=None): """ Dictionary filter function, walks through the input dict (obj) calling the callback function for each value. The callback function return is assigned the the corresponding dict value. 
:param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.value_traverse(data[key], callback) else: data[key] = callback(value) def transfer_attrs(self, cls, target_section): items_set = [] def set_item(item): items_set.append(item[0]) setattr(cls, item[0], item[1]) DeepDict.sectionconstraint_item_traverse(self.data, target_section, callback=set_item, section=None) return items_set @staticmethod def sectionconstraint_item_traverse(data, target_section, callback=None, section=None): """ Dictionary filter function, walks through the input dict (obj) calling the callback function for each item. The callback function then is called with the key value pair as tuple input but only for the target section. :param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.sectionconstraint_item_traverse(data[key], target_section, callback, key) else: if target_section == section: callback((key, value)) @staticmethod def item_traverse(data, callback=None): """ Dictionary filter function, walks through the input dict (obj) calling the callback function for each item. The callback function then is called with the key value pair as tuple input. 
:param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.value_traverse(data[key], callback) else: callback((key, value)) @staticmethod def parse_type(string): """ Type convert input string to float, int, list, tuple or string :param string: [str] input string :return: [T] converted output """ try: a = float(string) try: b = int(string) except ValueError: return float(string) if a == b: return b return a except ValueError: if string.startswith("[") and string.endswith("]"): string = re.sub(' ', '', string) elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return li elif string.startswith("(") and string.endswith(")"): elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return tuple(li) return string @property def data(self): return self._data @data.setter def data(self, value): if not isinstance(value, dict): - LOG.error(f"Input Error, expect dict type for value, but got {type(value)}") - raise IOError(f"Input Error, expect dict type for value, but got {type(value)}") + LOG.error("Input Error, expect dict type for value, but got {}".format(type(value))) + raise IOError("Input Error, expect dict type for value, but got {}".format(type(value))) self.clear() self._data = value @property def sep(self): return self._sep @sep.setter def sep(self, value): if not isinstance(value, str): - LOG.error(f"Input Error, expect str type for value, but got {type(value)}") - raise IOError(f"Input Error, expect str type for value, but got {type(value)}") + LOG.error("Input Error, expect str type for value, but got 
{}".format(type(value))) + raise IOError("Input Error, expect str type for value, but got {}".format(type(value))) self._sep = value diff --git a/hyppopy/plugins/hyperopt_settings_plugin.py b/hyppopy/plugins/hyperopt_settings_plugin.py index 2809ce6..8ce1f40 100644 --- a/hyppopy/plugins/hyperopt_settings_plugin.py +++ b/hyppopy/plugins/hyperopt_settings_plugin.py @@ -1,105 +1,105 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging import numpy as np from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat try: from hyperopt import hp from yapsy.IPlugin import IPlugin except: LOG.warning("hyperopt package not installed, will ignore this plugin!") print("hyperopt package not installed, will ignore this plugin!") from hyppopy.settingspluginbase import SettingsPluginBase from hyppopy.settingsparticle import SettingsParticle class hyperopt_Settings(SettingsPluginBase, IPlugin): def __init__(self): SettingsPluginBase.__init__(self) LOG.debug("initialized") def convert_parameter(self, input_dict): - LOG.debug(f"convert input parameter\n\n\t{pformat(input_dict)}\n") + LOG.debug("convert input parameter\n\n\t{}\n".format(pformat(input_dict))) solution_space = {} for name, content in input_dict.items(): particle = hyperopt_SettingsParticle(name=name) for key, value in content.items(): if key == 'domain': particle.domain = value elif key == 'data': particle.data = value elif key == 'type': particle.dtype = value solution_space[name] = particle.get() return solution_space class 
hyperopt_SettingsParticle(SettingsParticle): def __init__(self, name=None, domain=None, dtype=None, data=None): SettingsParticle.__init__(self, name, domain, dtype, data) def convert(self): if self.domain == "uniform": if self.dtype == "float" or self.dtype == "double": return hp.uniform(self.name, self.data[0], self.data[1]) elif self.dtype == "int": data = list(np.arange(int(self.data[0]), int(self.data[1]+1))) return hp.choice(self.name, data) else: - msg = f"cannot convert the type {self.dtype} in domain {self.domain}" + msg = "cannot convert the type {} in domain {}".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) elif self.domain == "loguniform": if self.dtype == "float" or self.dtype == "double": return hp.loguniform(self.name, self.data[0], self.data[1]) else: - msg = f"cannot convert the type {self.dtype} in domain {self.domain}" + msg = "cannot convert the type {} in domain {}".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) elif self.domain == "normal": if self.dtype == "float" or self.dtype == "double": return hp.normal(self.name, self.data[0], self.data[1]) else: - msg = f"cannot convert the type {self.dtype} in domain {self.domain}" + msg = "cannot convert the type {} in domain {}".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) elif self.domain == "categorical": if self.dtype == 'str': return hp.choice(self.name, self.data) elif self.dtype == 'bool': data = [] for elem in self.data: if elem == "true" or elem == "True" or elem == 1 or elem == "1": data .append(True) elif elem == "false" or elem == "False" or elem == 0 or elem == "0": data .append(False) else: - msg = f"cannot convert the type {self.dtype} in domain {self.domain}, unknown bool type value" + msg = "cannot convert the type {} in domain {}, unknown bool type value".format(self.dtype, self.domain) LOG.error(msg) raise LookupError(msg) return hp.choice(self.name, data) diff --git a/hyppopy/plugins/hyperopt_solver_plugin.py 
b/hyppopy/plugins/hyperopt_solver_plugin.py index 9f3f608..a8854ec 100644 --- a/hyppopy/plugins/hyperopt_solver_plugin.py +++ b/hyppopy/plugins/hyperopt_solver_plugin.py @@ -1,69 +1,69 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat from hyperopt import fmin, tpe, hp, STATUS_OK, STATUS_FAIL, Trials from yapsy.IPlugin import IPlugin from hyppopy.projectmanager import ProjectManager from hyppopy.solverpluginbase import SolverPluginBase class hyperopt_Solver(SolverPluginBase, IPlugin): trials = None best = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") def loss_function(self, params): try: loss = self.loss(self.data, params) status = STATUS_OK except Exception as e: - LOG.error(f"execution of self.loss(self.data, params) failed due to:\n {e}") + LOG.error("execution of self.loss(self.data, params) failed due to:\n {}".format(e)) status = STATUS_FAIL return {'loss': loss, 'status': status} def execute_solver(self, parameter): - LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n") + LOG.debug("execute_solver using solution space:\n\n\t{}\n".format(pformat(parameter))) self.trials = Trials() try: self.best = fmin(fn=self.loss_function, space=parameter, algo=tpe.suggest, max_evals=ProjectManager.max_iterations, trials=self.trials) except Exception as e: - msg = f"internal error in hyperopt.fmin occured. {e}" + msg = "internal error in hyperopt.fmin occured. 
{}".format(e) LOG.error(msg) raise BrokenPipeError(msg) def convert_results(self): txt = "" solution = dict([(k, v) for k, v in self.best.items() if v is not None]) txt += 'Solution Hyperopt Plugin\n========\n' txt += "\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items())) txt += "\n" return txt diff --git a/hyppopy/plugins/optunity_settings_plugin.py b/hyppopy/plugins/optunity_settings_plugin.py index da2e9d2..4a9a9d6 100644 --- a/hyppopy/plugins/optunity_settings_plugin.py +++ b/hyppopy/plugins/optunity_settings_plugin.py @@ -1,117 +1,117 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat try: import optunity from yapsy.IPlugin import IPlugin except: LOG.warning("optunity package not installed, will ignore this plugin!") print("optunity package not installed, will ignore this plugin!") from hyppopy.settingspluginbase import SettingsPluginBase from hyppopy.settingsparticle import SettingsParticle class optunity_Settings(SettingsPluginBase, IPlugin): def __init__(self): SettingsPluginBase.__init__(self) LOG.debug("initialized") def convert_parameter(self, input_dict): - LOG.debug(f"convert input parameter\n\n\t{pformat(input_dict)}\n") + LOG.debug("convert input parameter\n\n\t{}\n".format(pformat(input_dict))) # define function spliting input dict # into categorical and non-categorical def split_categorical(pdict): categorical = {} uniform = {} for name, pset in pdict.items(): for key, value in pset.items(): if key == 'domain' and 
value == 'categorical': categorical[name] = pset elif key == 'domain': uniform[name] = pset return categorical, uniform solution_space = {} # split input in categorical and non-categorical data cat, uni = split_categorical(input_dict) # build up dictionary keeping all non-categorical data uniforms = {} for key, value in uni.items(): for key2, value2 in value.items(): if key2 == 'data': uniforms[key] = value2 # build nested categorical structure inner_level = uniforms for key, value in cat.items(): tmp = {} tmp2 = {} for key2, value2 in value.items(): if key2 == 'data': for elem in value2: tmp[elem] = inner_level tmp2[key] = tmp inner_level = tmp2 solution_space = tmp2 return solution_space # class optunity_SettingsParticle(SettingsParticle): # # def __init__(self, name=None, domain=None, dtype=None, data=None): # SettingsParticle.__init__(self, name, domain, dtype, data) # # def convert(self): # if self.domain == "uniform": # if self.dtype == "float" or self.dtype == "double": # pass # elif self.dtype == "int": # pass # else: # msg = f"cannot convert the type {self.dtype} in domain {self.domain}" # LOG.error(msg) # raise LookupError(msg) # elif self.domain == "loguniform": # if self.dtype == "float" or self.dtype == "double": # pass # else: # msg = f"cannot convert the type {self.dtype} in domain {self.domain}" # LOG.error(msg) # raise LookupError(msg) # elif self.domain == "normal": # if self.dtype == "float" or self.dtype == "double": # pass # else: # msg = f"cannot convert the type {self.dtype} in domain {self.domain}" # LOG.error(msg) # raise LookupError(msg) # elif self.domain == "categorical": # if self.dtype == 'str': # pass # elif self.dtype == 'bool': # pass diff --git a/hyppopy/plugins/optunity_solver_plugin.py b/hyppopy/plugins/optunity_solver_plugin.py index c5e47f6..6416920 100644 --- a/hyppopy/plugins/optunity_solver_plugin.py +++ b/hyppopy/plugins/optunity_solver_plugin.py @@ -1,71 +1,71 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # 
Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat import optunity from yapsy.IPlugin import IPlugin from hyppopy.projectmanager import ProjectManager from hyppopy.solverpluginbase import SolverPluginBase class optunity_Solver(SolverPluginBase, IPlugin): solver_info = None trials = None best = None status = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") def loss_function(self, **params): try: loss = self.loss(self.data, params) self.status.append('ok') return loss except Exception as e: - LOG.error(f"computing loss failed due to:\n {e}") + LOG.error("computing loss failed due to:\n {}".format(e)) self.status.append('fail') return 1e9 def execute_solver(self, parameter): - LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n") + LOG.debug("execute_solver using solution space:\n\n\t{}\n".format(pformat(parameter))) self.status = [] try: self.best, self.trials, self.solver_info = optunity.minimize_structured(f=self.loss_function, num_evals=ProjectManager.max_iterations, search_space=parameter) except Exception as e: - LOG.error(f"internal error in optunity.minimize_structured occured. {e}") - raise BrokenPipeError(f"internal error in optunity.minimize_structured occured. {e}") + LOG.error("internal error in optunity.minimize_structured occured. {}".format(e)) + raise BrokenPipeError("internal error in optunity.minimize_structured occured. 
{}".format(e)) def convert_results(self): solution = dict([(k, v) for k, v in self.best.items() if v is not None]) txt = "" txt += 'Solution Optunity Plugin\n========\n' txt += "\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items())) - txt += f"\nSolver used: {self.solver_info['solver_name']}" - txt += f"\nOptimum: {self.trials.optimum}" - txt += f"\nIterations used: {self.trials.stats['num_evals']}" - txt += f"\nDuration: {self.trials.stats['time']} s\n" + txt += "\nSolver used: {}".format(self.solver_info['solver_name']) + txt += "\nOptimum: {}".format(self.trials.optimum) + txt += "\nIterations used: {}".format(self.trials.stats['num_evals']) + txt += "\nDuration: {} s\n".format(self.trials.stats['time']) return txt diff --git a/hyppopy/projectmanager.py b/hyppopy/projectmanager.py index a85ece0..3eab4e3 100644 --- a/hyppopy/projectmanager.py +++ b/hyppopy/projectmanager.py @@ -1,111 +1,111 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
# # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.singleton import * from hyppopy.deepdict import DeepDict from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) @singleton_object class ProjectManager(metaclass=Singleton): def __init__(self): self.configfilename = None self.config = None self._extmembers = [] def clear(self): self.configfilename = None self.config = None self.remove_externals() def is_ready(self): return self.config is not None def remove_externals(self): for added in self._extmembers: if added in self.__dict__.keys(): del self.__dict__[added] self._extmembers = [] def get_hyperparameter(self): return self.config["hyperparameter"] def test_config(self): if not isinstance(self.config, DeepDict): - msg = f"test_config failed, config is not of type DeepDict" + msg = "test_config failed, config is not of type DeepDict" LOG.error(msg) return False sections = ["hyperparameter"] sections += SETTINGSSOLVERPATH.split("/") sections += SETTINGSCUSTOMPATH.split("/") for sec in sections: if not self.config.has_section(sec): - msg = f"test_config failed, config has no section {sec}" + msg = "test_config failed, config has no section {}".format(sec) LOG.error(msg) return False return True def set_config(self, config): self.clear() if isinstance(config, dict): self.config = DeepDict() self.config.data = config elif isinstance(config, DeepDict): self.config = config else: - msg = f"unknown type ({type(config)}) for config passed, expected dict or DeepDict" + msg = "unknown type ({}) for config passed, expected dict or DeepDict".format(type(config)) LOG.error(msg) raise IOError(msg) if not self.test_config(): self.clear() return False try: self._extmembers += self.config.transfer_attrs(self, SETTINGSCUSTOMPATH.split("/")[-1]) self._extmembers += self.config.transfer_attrs(self, 
SETTINGSSOLVERPATH.split("/")[-1]) except Exception as e: - msg = f"transfering custom section as class attributes failed, " \ - f"is the config path to your custom section correct? {SETTINGSCUSTOMPATH}. Exception {e}" + msg = "transfering custom section as class attributes failed, " \ + "is the config path to your custom section correct? {}. Exception {}".format(SETTINGSCUSTOMPATH, e) LOG.error(msg) raise LookupError(msg) return True def read_config(self, configfile): self.clear() self.configfilename = configfile self.config = DeepDict(configfile) if not self.test_config(): self.clear() return False try: self._extmembers += self.config.transfer_attrs(self, SETTINGSCUSTOMPATH.split("/")[-1]) self._extmembers += self.config.transfer_attrs(self, SETTINGSSOLVERPATH.split("/")[-1]) except Exception as e: - msg = f"transfering custom section as class attributes failed, " \ - f"is the config path to your custom section correct? {SETTINGSCUSTOMPATH}. Exception {e}" + msg = "transfering custom section as class attributes failed, " \ + "is the config path to your custom section correct? {}. Exception {}".format(SETTINGSCUSTOMPATH, e) LOG.error(msg) raise LookupError(msg) return True diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py index 5403acb..f732e58 100644 --- a/hyppopy/settingspluginbase.py +++ b/hyppopy/settingspluginbase.py @@ -1,78 +1,78 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details.
# # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import copy import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from hyppopy.globals import SETTINGSSOLVERPATH, SETTINGSCUSTOMPATH from hyppopy.deepdict import DeepDict class SettingsPluginBase(object): _data = None _name = None def __init__(self): self._data = {} @abc.abstractmethod def convert_parameter(self): raise NotImplementedError('users must define convert_parameter to use this base class') def get_hyperparameter(self): return self.convert_parameter(self.data) def set_hyperparameter(self, input_data): self.data.clear() self.data = copy.deepcopy(input_data) def read(self, fname): self.data.clear() self.data.from_file(fname) def write(self, fname): self.data.to_file(fname) @property def data(self): return self._data @data.setter def data(self, value): if isinstance(value, dict): self._data = value elif isinstance(value, DeepDict): self._data = value.data else: - raise IOError(f"unexpected input type({type(value)}) for data, needs to be of type dict or DeepDict!") + raise IOError("unexpected input type({}) for data, needs to be of type dict or DeepDict!".format(type(value))) @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): - LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") - raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") + LOG.error("Invalid input, str type expected for value, got {} instead".format(type(value))) + raise IOError("Invalid input, str type expected for value, got {} instead".format(type(value))) self._name = value diff --git a/hyppopy/solver.py b/hyppopy/solver.py index de02af9..08e7d83 100644 --- a/hyppopy/solver.py +++ b/hyppopy/solver.py @@ -1,85 +1,86 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. 
# All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.projectmanager import ProjectManager import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Solver(object): _name = None _solver_plugin = None _settings_plugin = None def __init__(self): pass def set_data(self, data): self.solver.set_data(data) def set_hyperparameters(self, params): self.settings.set_hyperparameter(params) def set_loss_function(self, loss_func): self.solver.set_loss_function(loss_func) def run(self): if not ProjectManager.is_ready(): LOG.error("No config data found to initialize PluginSetting object") raise IOError("No config data found to initialize PluginSetting object") hyps = ProjectManager.get_hyperparameter() self.settings.set_hyperparameter(hyps) self.solver.settings = self.settings self.solver.run() def get_results(self): return self.solver.get_results() @property def is_ready(self): return self.solver is not None and self.settings is not None @property def solver(self): return self._solver_plugin @solver.setter def solver(self, value): self._solver_plugin = value @property def settings(self): return self._settings_plugin @settings.setter def settings(self, value): self._settings_plugin = value @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): - LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") - raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") + msg = "Invalid input, str type expected for value, got {} instead".format(type(value)) + LOG.error(msg) + raise IOError(msg) self._name = value diff --git a/hyppopy/solverfactory.py 
b/hyppopy/solverfactory.py index fad3f74..17ff3e8 100644 --- a/hyppopy/solverfactory.py +++ b/hyppopy/solverfactory.py @@ -1,155 +1,158 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from yapsy.PluginManager import PluginManager from hyppopy.projectmanager import ProjectManager from hyppopy.globals import PLUGIN_DEFAULT_DIR from hyppopy.deepdict import DeepDict from hyppopy.solver import Solver from hyppopy.singleton import * import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) @singleton_object class SolverFactory(metaclass=Singleton): """ This class is responsible for grabbing all plugins from the plugin folder arranging them into a Solver class instances. These Solver class instances can be requested from the factory via the get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using SolverFactory(), the consequences will be horrific. Instead use is like a class having static functions only, SolverFactory.method(). 
""" _plugin_dirs = [] _plugins = {} def __init__(self): print("Solverfactory: I'am alive!") self.reset() self.load_plugins() LOG.debug("Solverfactory initialized") def load_plugins(self): """ Load plugin modules from plugin paths """ LOG.debug("load_plugins()") manager = PluginManager() - LOG.debug(f"setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) + LOG.debug("setPluginPlaces(" + " ".join(map(str, self._plugin_dirs))) manager.setPluginPlaces(self._plugin_dirs) manager.collectPlugins() for plugin in manager.getAllPlugins(): name_elements = plugin.plugin_object.__class__.__name__.split("_") LOG.debug("found plugin " + " ".join(map(str, name_elements))) if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements): - LOG.error(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") - raise NameError(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.") + msg = "invalid plugin class naming for class {}, the convention is libname_Solver or libname_Settings.".format(plugin.plugin_object.__class__.__name__) + LOG.error(msg) + raise NameError(msg) if name_elements[0] not in self._plugins.keys(): self._plugins[name_elements[0]] = Solver() self._plugins[name_elements[0]].name = name_elements[0] if name_elements[1] == "Solver": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].solver = obj - LOG.info(f"plugin: {name_elements[0]} Solver loaded") + LOG.info("plugin: {} Solver loaded".format(name_elements[0])) except Exception as e: - LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") - raise ImportError(f"Failed to instanciate class {plugin.plugin_object.__class__.__name__}") + msg = "failed to instanciate class {}".format(plugin.plugin_object.__class__.__name__) + LOG.error(msg) + 
raise ImportError(msg) elif name_elements[1] == "Settings": try: obj = plugin.plugin_object.__class__() obj.name = name_elements[0] self._plugins[name_elements[0]].settings = obj - LOG.info(f"plugin: {name_elements[0]} ParameterSpace loaded") + LOG.info("plugin: {} ParameterSpace loaded".format(name_elements[0])) except Exception as e: - LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") - raise ImportError(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}") + msg = "failed to instanciate class {}".format(plugin.plugin_object.__class__.__name__) + LOG.error(msg) + raise ImportError(msg) else: - LOG.error(f"failed loading plugin {name_elements[0]}, please check if naming conventions are kept!") - raise IOError(f"failed loading plugin {name_elements[0]}!, please check if naming conventions are kept!") + msg = "failed loading plugin {}, please check if naming conventions are kept!".format(name_elements[0]) + LOG.error(msg) + raise IOError(msg) if len(self._plugins) == 0: msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!" 
LOG.error(msg) raise IOError(msg) def reset(self): """ Reset solver factory """ LOG.debug("reset()") self._plugins = {} self._plugin_dirs = [] self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR)) def add_plugin_dir(self, dir): """ Add plugin directory """ - LOG.debug(f"add_plugin_dir({dir})") + LOG.debug("add_plugin_dir({})".format(dir)) self._plugin_dirs.append(dir) def list_solver(self): """ list all solvers available :return: [list(str)] """ return list(self._plugins.keys()) def from_settings(self, settings): if isinstance(settings, str): if not os.path.isfile(settings): - LOG.error(f"input error, file {settings} not found!") + LOG.error("input error, file {} not found!".format(settings)) if not ProjectManager.read_config(settings): LOG.error("failed to read config in ProjectManager!") return None else: if not ProjectManager.set_config(settings): LOG.error("failed to set config in ProjectManager!") return None if not ProjectManager.is_ready(): LOG.error("failed to set config in ProjectManager!") return None try: solver = self.get_solver(ProjectManager.use_plugin) except Exception as e: - msg = f"failed to create solver, reason {e}" + msg = "failed to create solver, reason {}".format(e) LOG.error(msg) return None return solver def get_solver(self, name): """ returns a solver by name tag :param name: [str] solver name :return: [Solver] instance """ if not isinstance(name, str): - msg = f"Invalid input, str type expected for name, got {type(name)} instead" + msg = "Invalid input, str type expected for name, got {} instead".format(type(name)) LOG.error(msg) raise IOError(msg) if name not in self.list_solver(): - msg = f"failed solver request, a solver called {name} is not available, " \ - f"check for typo or if your plugin failed while loading!" 
+ msg = "failed solver request, a solver called {} is not available, check for typo or if your plugin failed while loading!".format(name) LOG.error(msg) raise LookupError(msg) - LOG.debug(f"get_solver({name})") + LOG.debug("get_solver({})".format(name)) return self._plugins[name] diff --git a/hyppopy/solverpluginbase.py b/hyppopy/solverpluginbase.py index e880e80..10944a3 100644 --- a/hyppopy/solverpluginbase.py +++ b/hyppopy/solverpluginbase.py @@ -1,84 +1,84 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import logging from hyppopy.globals import DEBUGLEVEL from hyppopy.settingspluginbase import SettingsPluginBase LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class SolverPluginBase(object): data = None loss = None _settings = None _name = None def __init__(self): pass @abc.abstractmethod def loss_function(self, params): raise NotImplementedError('users must define loss_func to use this base class') @abc.abstractmethod def execute_solver(self): raise NotImplementedError('users must define execute_solver to use this base class') @abc.abstractmethod def convert_results(self): raise NotImplementedError('users must define convert_results to use this base class') def set_data(self, data): self.data = data def set_loss_function(self, func): self.loss = func def get_results(self): return self.convert_results() def run(self): self.execute_solver(self.settings.get_hyperparameter()) @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): - msg = f"Invalid input, str type expected for value, got {type(value)} instead" + msg = "Invalid input, 
str type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._name = value @property def settings(self): return self._settings @settings.setter def settings(self, value): if not isinstance(value, SettingsPluginBase): - msg = f"Invalid input, SettingsPluginBase type expected for value, got {type(value)} instead" + msg = "Invalid input, SettingsPluginBase type expected for value, got {} instead".format(type(value)) LOG.error(msg) raise IOError(msg) self._settings = value diff --git a/hyppopy/workflows/unet_usecase/unet_uscase_utils.py b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py index 9b9147c..a5028f1 100644 --- a/hyppopy/workflows/unet_usecase/unet_uscase_utils.py +++ b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py @@ -1,417 +1,417 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
# # Author: Sven Wanner (s.wanner@dkfz.de) import os import torch import pickle import fnmatch import numpy as np from torch import nn from medpy.io import load from collections import defaultdict from abc import ABCMeta, abstractmethod def sum_tensor(input, axes, keepdim=False): axes = np.unique(axes).astype(int) if keepdim: for ax in axes: input = input.sum(int(ax), keepdim=True) else: for ax in sorted(axes, reverse=True): input = input.sum(int(ax)) return input def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None): if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]: rebalance_weights = rebalance_weights[1:] # this is the case when use_bg=False axes = tuple([0] + list(range(2, len(net_output.size())))) tp = sum_tensor(net_output * gt, axes, keepdim=False) fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False) fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False) weights = torch.ones(tp.shape) weights[0] = background_weight if net_output.device.type == "cuda": weights = weights.cuda(net_output.device.index) if rebalance_weights is not None: rebalance_weights = torch.from_numpy(rebalance_weights).float() if net_output.device.type == "cuda": rebalance_weights = rebalance_weights.cuda(net_output.device.index) tp = tp * rebalance_weights fn = fn * rebalance_weights result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean() return result def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.): axes = tuple(range(2, len(net_output.size()))) intersect = sum_tensor(net_output * gt, axes, keepdim=False) denom = sum_tensor(net_output + gt, axes, keepdim=False) result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth)) * weights).mean() #TODO: Was ist weights and er Stelle? 
return result class SoftDiceLoss(nn.Module): def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True, background_weight=1, rebalance_weights=None): """ hahaa no documentation for you today :param smooth: :param apply_nonlin: :param batch_dice: :param do_bg: :param smooth_in_nom: :param background_weight: :param rebalance_weights: """ super(SoftDiceLoss, self).__init__() if not do_bg: assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy" self.rebalance_weights = rebalance_weights self.background_weight = background_weight self.smooth_in_nom = smooth_in_nom self.do_bg = do_bg self.batch_dice = batch_dice self.apply_nonlin = apply_nonlin self.smooth = smooth self.y_onehot = None if not smooth_in_nom: self.nom_smooth = 0 else: self.nom_smooth = smooth def forward(self, x, y): with torch.no_grad(): y = y.long() shp_x = x.shape shp_y = y.shape if self.apply_nonlin is not None: x = self.apply_nonlin(x) if len(shp_x) != len(shp_y): y = y.view((shp_y[0], 1, *shp_y[1:])) # now x and y should have shape (B, C, X, Y(, Z))) and (B, 1, X, Y(, Z))), respectively y_onehot = torch.zeros(shp_x) if x.device.type == "cuda": y_onehot = y_onehot.cuda(x.device.index) y_onehot.scatter_(1, y, 1) if not self.do_bg: x = x[:, 1:] y_onehot = y_onehot[:, 1:] if not self.batch_dice: if self.background_weight != 1 or (self.rebalance_weights is not None): raise NotImplementedError("nah son") l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom) else: l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom, background_weight=self.background_weight, rebalance_weights=self.rebalance_weights) return l def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None): fls = [] files_len = [] slices_ax = [] for root, dirs, files in os.walk(base_dir): i = 0 for filename in sorted(fnmatch.filter(files, pattern)): if keys is not None and filename[:-4] in keys: npy_file = os.path.join(root, filename) 
numpy_array = np.load(npy_file, mmap_mode="r") fls.append(npy_file) files_len.append(numpy_array.shape[1]) slices_ax.extend([(i, j) for j in range(slice_offset, files_len[-1] - slice_offset)]) i += 1 return fls, files_len, slices_ax, class SlimDataLoaderBase(object): def __init__(self, data, batch_size, number_of_threads_in_multithreaded=None): """ Slim version of DataLoaderBase (which is now deprecated). Only provides very simple functionality. You must derive from this class to implement your own DataLoader. You must overrive self.generate_train_batch() If you use our MultiThreadedAugmenter you will need to also set and use number_of_threads_in_multithreaded. See multithreaded_dataloading in examples! :param data: will be stored in self._data. You can use it to generate your batches in self.generate_train_batch() :param batch_size: will be stored in self.batch_size for use in self.generate_train_batch() :param number_of_threads_in_multithreaded: will be stored in self.number_of_threads_in_multithreaded. None per default. 
If you wish to iterate over all your training data only once per epoch, you must coordinate your Dataloaders and you will need this information """ __metaclass__ = ABCMeta self.number_of_threads_in_multithreaded = number_of_threads_in_multithreaded self._data = data self.batch_size = batch_size self.thread_id = 0 def set_thread_id(self, thread_id): self.thread_id = thread_id def __iter__(self): return self def __next__(self): return self.generate_train_batch() @abstractmethod def generate_train_batch(self): '''override this Generate your batch from self._data .Make sure you generate the correct batch size (self.BATCH_SIZE) ''' pass class NumpyDataLoader(SlimDataLoaderBase): def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None): self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern, slice_offset=0, keys=keys, ) super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches) self.batch_size = batch_size self.use_next = False if mode == "train": self.use_next = False self.slice_idxs = list(range(0, len(self.slices))) self.data_len = len(self.slices) self.num_batches = min((self.data_len // self.batch_size)+10, num_batches) if isinstance(label_slice, int): label_slice = (label_slice,) self.input_slice = input_slice self.label_slice = label_slice self.np_data = np.asarray(self.slices) def reshuffle(self): print("Reshuffle...") random.shuffle(self.slice_idxs) print("Initializing... 
this might take a while...") def generate_train_batch(self): open_arr = random.sample(self._data, self.batch_size) return self.get_data_from_array(open_arr) def __len__(self): n_items = min(self.data_len // self.batch_size, self.num_batches) return n_items def __getitem__(self, item): slice_idxs = self.slice_idxs data_len = len(self.slices) np_data = self.np_data if item > len(self): raise StopIteration() if (item * self.batch_size) == data_len: raise StopIteration() start_idx = (item * self.batch_size) % data_len stop_idx = ((item + 1) * self.batch_size) % data_len if ((item + 1) * self.batch_size) == data_len: stop_idx = data_len if stop_idx > start_idx: idxs = slice_idxs[start_idx:stop_idx] else: raise StopIteration() open_arr = np_data[idxs] return self.get_data_from_array(open_arr) def get_data_from_array(self, open_array): data = [] fnames = [] slice_idxs = [] labels = [] for slice in open_array: fn_name = self.files[slice[0]] numpy_array = np.load(fn_name, mmap_mode="r") numpy_slice = numpy_array[ :, slice[1], ] data.append(numpy_slice[None, self.input_slice[0]]) # 'None' keeps the dimension if self.label_slice is not None: labels.append(numpy_slice[None, self.label_slice[0]]) # 'None' keeps the dimension fnames.append(self.files[slice[0]]) slice_idxs.append(slice[1]) ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs} if self.label_slice is not None: ret_dict['seg'] = np.asarray(labels) return ret_dict class NumpyDataSet(object): """ TODO """ def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128, file_pattern='*.npy', label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None): data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, seed=seed, file_pattern=file_pattern, input_slice=input_slice, label_slice=label_slice, keys=keys) self.data_loader = data_loader self.batch_size = 
batch_size self.do_reshuffle = do_reshuffle self.number_of_slices = 1 self.transforms = get_transforms(mode=mode, target_size=target_size) self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes, num_cached_per_queue=num_cached_per_queue, seeds=seed, shuffle=do_reshuffle) self.augmenter.restart() def __len__(self): return len(self.data_loader) def __iter__(self): if self.do_reshuffle: self.data_loader.reshuffle() self.augmenter.renew() return self.augmenter def __next__(self): return next(self.augmenter) def reshape(orig_img, append_value=-1024, new_shape=(512, 512, 512)): reshaped_image = np.zeros(new_shape) reshaped_image[...] = append_value x_offset = 0 y_offset = 0 # (new_shape[1] - orig_img.shape[1]) // 2 z_offset = 0 # (new_shape[2] - orig_img.shape[2]) // 2 reshaped_image[x_offset:orig_img.shape[0] + x_offset, y_offset:orig_img.shape[1] + y_offset, z_offset:orig_img.shape[2] + z_offset] = orig_img # insert temp_img.min() as background value return reshaped_image def subfiles(folder, join=True, prefix=None, suffix=None, sort=True): if join: l = os.path.join else: l = lambda x, y: y res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) and (prefix is None or i.startswith(prefix)) and (suffix is None or i.endswith(suffix))] if sort: res.sort() return res def preprocess_data(root_dir): print("preprocess data...") image_dir = os.path.join(root_dir, 'imagesTr') - print(f"image_dir: {image_dir}") + print("image_dir: {}".format(image_dir)) label_dir = os.path.join(root_dir, 'labelsTr') - print(f"label_dir: {label_dir}") + print("label_dir: {}".format(label_dir)) output_dir = os.path.join(root_dir, 'preprocessed') - print(f"output_dir: {output_dir} ... ", end="") + print("output_dir: {} ... 
".format(output_dir), end="") classes = 3 if not os.path.exists(output_dir): os.makedirs(output_dir) print("created!") else: print("found!\npreprocessed data already available, aborted preprocessing!") return False print("start preprocessing ... ", end="") class_stats = defaultdict(int) total = 0 nii_files = subfiles(image_dir, suffix=".nii.gz", join=False) for i in range(0, len(nii_files)): if nii_files[i].startswith("._"): nii_files[i] = nii_files[i][2:] for i, f in enumerate(nii_files): image, _ = load(os.path.join(image_dir, f)) label, _ = load(os.path.join(label_dir, f.replace('_0000', ''))) for i in range(classes): class_stats[i] += np.sum(label == i) total += np.sum(label == i) image = (image - image.min()) / (image.max() - image.min()) image = reshape(image, append_value=0, new_shape=(64, 64, 64)) label = reshape(label, append_value=0, new_shape=(64, 64, 64)) result = np.stack((image, label)) np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result) print("finished!") return True def create_splits(output_dir, image_dir): print("creating splits ... 
", end="") npy_files = subfiles(image_dir, suffix=".npy", join=False) trainset_size = len(npy_files) * 50 // 100 valset_size = len(npy_files) * 25 // 100 testset_size = len(npy_files) * 25 // 100 splits = [] for split in range(0, 5): image_list = npy_files.copy() trainset = [] valset = [] testset = [] for i in range(0, trainset_size): patient = np.random.choice(image_list) image_list.remove(patient) trainset.append(patient[:-4]) for i in range(0, valset_size): patient = np.random.choice(image_list) image_list.remove(patient) valset.append(patient[:-4]) for i in range(0, testset_size): patient = np.random.choice(image_list) image_list.remove(patient) testset.append(patient[:-4]) split_dict = dict() split_dict['train'] = trainset split_dict['val'] = valset split_dict['test'] = testset splits.append(split_dict) with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f: pickle.dump(splits, f) print("finished!") diff --git a/hyppopy/workflows/unet_usecase/unet_usecase.py b/hyppopy/workflows/unet_usecase/unet_usecase.py index 545a97c..3e6046d 100644 --- a/hyppopy/workflows/unet_usecase/unet_usecase.py +++ b/hyppopy/workflows/unet_usecase/unet_usecase.py @@ -1,132 +1,132 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
# # Author: Sven Wanner (s.wanner@dkfz.de) import os import torch import numpy as np import pandas as pd from sklearn.svm import SVC import torch.optim as optim import torch.nn.functional as F from .networks.RecursiveUNet import UNet from .loss_functions.dice_loss import SoftDiceLoss from sklearn.model_selection import cross_val_score from torch.optim.lr_scheduler import ReduceLROnPlateau from .datasets.two_dim.NumpyDataLoader import NumpyDataSet from hyppopy.projectmanager import ProjectManager from hyppopy.workflows.workflowbase import WorkflowBase from hyppopy.workflows.dataloader.unetloader import UnetDataLoader class unet_usecase(WorkflowBase): def setup(self): dl = UnetDataLoader() dl.start(data_path=ProjectManager.data_path, data_name=ProjectManager.data_name, image_dir=ProjectManager.image_dir, labels_dir=ProjectManager.labels_dir, split_dir=ProjectManager.split_dir, output_dir=ProjectManager.data_path, num_classes=ProjectManager.num_classes) self.solver.set_data(dl.data) def blackbox_function(self, data, params): if "batch_size" in params.keys(): params["batch_size"] = int(round(params["batch_size"])) if "batch_size" in params.keys(): params["batch_size"] = int(round(params["batch_size"])) if "n_epochs" in params.keys(): params["n_epochs"] = int(round(params["n_epochs"])) batch_size = 8 patch_size = 64 tr_keys = data[ProjectManager.fold]['train'] val_keys = data[ProjectManager.fold]['val'] data_dir = os.path.join(ProjectManager.data_path, *(ProjectManager.data_name, ProjectManager.preprocessed_dir)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_data_loader = NumpyDataSet(data_dir, target_size=patch_size, batch_size=batch_size, keys=tr_keys) val_data_loader = NumpyDataSet(data_dir, target_size=patch_size, batch_size=batch_size, keys=val_keys, mode="val", do_reshuffle=False) model = UNet(num_classes=ProjectManager.num_classes, in_channels=ProjectManager.in_channels) model.to(device) # We use a combination of DICE-loss and 
CE-Loss in this example. # This proved good in the medical segmentation decathlon. dice_loss = SoftDiceLoss(batch_dice=True) # Softmax für DICE Loss! ce_loss = torch.nn.CrossEntropyLoss() # Kein Softmax für CE Loss -> ist in torch schon mit drin! optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"]) scheduler = ReduceLROnPlateau(optimizer, 'min') losses = [] - print(f"n_epochs {params['n_epochs']}") + print("n_epochs {}".format(params['n_epochs'])) for epoch in range(params["n_epochs"]): #### Train #### model.train() data = None batch_counter = 0 for data_batch in train_data_loader: optimizer.zero_grad() # Shape of data_batch = [1, b, c, w, h] # Desired shape = [b, c, w, h] # Move data and target to the GPU data = data_batch['data'][0].float().to(device) target = data_batch['seg'][0].long().to(device) pred = model(data) pred_softmax = F.softmax(pred, dim=1) # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally. loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze()) loss.backward() optimizer.step() batch_counter += 1 ############### #### Validate #### model.eval() data = None loss_list = [] with torch.no_grad(): for data_batch in val_data_loader: data = data_batch['data'][0].float().to(device) target = data_batch['seg'][0].long().to(device) pred = model(data) pred_softmax = F.softmax(pred) # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally. loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze()) loss_list.append(loss.item()) assert data is not None, 'data is None. Please check if your dataloader works properly' scheduler.step(np.mean(loss_list)) losses.append(np.mean(loss_list)) ################## return np.mean(losses)