diff --git a/bin/hyppopy_exe.py b/bin/hyppopy_exe.py index a69d792..2296c38 100644 --- a/bin/hyppopy_exe.py +++ b/bin/hyppopy_exe.py @@ -1,23 +1,85 @@ +#!/usr/bin/env python # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) -from hyppopy.cmdtools import * +from hyppopy.workflows.unet_usecase import unet_usecase +from hyppopy.workflows.svc_usecase import svc_usecase +from hyppopy.workflows.randomforest_usecase import randomforest_usecase -if __name__ == '__main__': - cmd_workflow() +import os +import sys +import argparse +import hyppopy.solverfactory as sfac + + +solver_factory = sfac.SolverFactory.instance() + + +def print_warning(msg): + print("\n!!!!! WARNING !!!!!") + print(msg) + sys.exit() + + +def args_check(args): + if not args.workflow: + print_warning("No workflow specified, check --help") + if not args.config: + print_warning("Missing config parameter, check --help") + if not args.data: + print_warning("Missing data parameter, check --help") + if not os.path.isdir(args.data): + print_warning("Couldn't find data path, please check your input --data") + + if not os.path.isfile(args.config): + tmp = os.path.join(args.data, args.config) + if not os.path.isfile(tmp): + print_warning("Couldn't find the config file, please check your input --config") + args.config = tmp + if args.plugin not in solver_factory.list_solver(): + print_warning(f"The requested plugin {args.plugin} is not available, please check for typos. Plugin options :" + f"{', '.join(solver_factory.list_solver())}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='UNet Hyppopy UseCase Example Optimization.') + parser.add_argument('-w', '--workflow', type=str, + help='workflow to be executed') + parser.add_argument('-p', '--plugin', type=str, default='hyperopt', + help='plugin to be used default=[hyperopt], optunity') + parser.add_argument('-d', '--data', type=str, help='training data path') + parser.add_argument('-c', '--config', type=str, help='config filename, .xml or .json formats are supported.' + 'pass a full path filename or the filename only if the' + 'configfile is in the data folder') + parser.add_argument('-i', '--iterations', type=int, default=0, + help='number of iterations, default=[0] if set to 0 the value set via configfile is used, ' + 'otherwise the configfile value will be overwritten') + + args = parser.parse_args() + + args_check(args) + + if args.workflow == "svc_usecase": + svc_usecase.svc_usecase(args) + elif args.workflow == "randomforest_usecase": + randomforest_usecase.randomforest_usecase(args) + elif args.workflow == "unet_usecase": + unet_usecase.unet_usecase(args) + else: + print(f"No workflow called {args.workflow} found!") diff --git a/hyppopy/cmdtools.py b/hyppopy/cmdtools.py deleted file mode 100644 index 37151d0..0000000 --- a/hyppopy/cmdtools.py +++ /dev/null @@ -1,68 +0,0 @@ -# -*- coding: utf-8 -*- -# -# DKFZ -# -# -# Copyright (c) German Cancer Research Center, -# Division of Medical and Biological Informatics. -# All rights reserved. -# -# This software is distributed WITHOUT ANY WARRANTY; without -# even the implied warranty of MERCHANTABILITY or FITNESS FOR -# A PARTICULAR PURPOSE. 
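For reference, the new command line entry point above simply forwards the parsed arguments to a workflow function, so the same call can be made programmatically. A minimal sketch, assuming the package layout from this diff; the data and config paths are placeholders:

    from argparse import Namespace

    from hyppopy.workflows.svc_usecase import svc_usecase

    # placeholder paths; point them at a real data folder and config file
    args = Namespace(workflow="svc_usecase", plugin="hyperopt",
                     data="/path/to/data", config="config.json", iterations=0)

    # equivalent to: python bin/hyppopy_exe.py -w svc_usecase -p hyperopt -d /path/to/data -c config.json
    svc_usecase.svc_usecase(args)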
-# -# See LICENSE.txt or http://www.mitk.org for details. -# -# Author: Sven Wanner (s.wanner@dkfz.de) - -import argparse - -from sklearn.svm import SVC -from sklearn import datasets -from sklearn.model_selection import cross_val_score -from sklearn.model_selection import train_test_split - -import logging -LOG = logging.getLogger('hyppopy') - -from hyppopy.solverfactory import SolverFactory - - -def cmd_workflow(): - parser = argparse.ArgumentParser(description="") - - parser.add_argument('-v', '--verbosity', type=int, required=False, default=0, - help='') - - - args_dict = vars(parser.parse_args()) - - iris = datasets.load_iris() - X, X_test, y, y_test = train_test_split(iris.data, iris.target, test_size=0.1, random_state=42) - my_IRIS_dta = [X, y] - - my_SVC_parameter = { - 'C': {'domain': 'uniform', 'data': [0, 20]}, - 'gamma': {'domain': 'uniform', 'data': [0.0001, 20.0]}, - 'kernel': {'domain': 'categorical', 'data': ['linear', 'sigmoid', 'poly', 'rbf']} - } - - def my_SVC_loss_func(data, params): - clf = SVC(**params) - return -cross_val_score(clf, data[0], data[1], cv=3).mean() - - factory = SolverFactory() - - solver = factory.get_solver('optunity') - solver.set_data(my_IRIS_dta) - solver.set_parameters(my_SVC_parameter) - solver.set_loss_function(my_SVC_loss_func) - solver.run() - solver.get_results() - - solver = factory.get_solver('hyperopt') - solver.set_data(my_IRIS_dta) - solver.set_parameters(my_SVC_parameter) - solver.set_loss_function(my_SVC_loss_func) - solver.run() - solver.get_results() diff --git a/hyppopy/deepdict.py b/hyppopy/deepdict.py index b7b910d..9a31c8a 100644 --- a/hyppopy/deepdict.py +++ b/hyppopy/deepdict.py @@ -1,385 +1,386 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import re import json import types import pprint import xmltodict from dicttoxml import dicttoxml from collections import OrderedDict import logging LOG = logging.getLogger('hyppopy') from hyppopy.globals import DEEPDICT_XML_ROOT def convert_ordered2std_dict(obj): """ Helper function converting an OrderedDict into a standard lib dict. :param obj: [OrderedDict] """ for key, value in obj.items(): if isinstance(value, OrderedDict): obj[key] = dict(obj[key]) convert_ordered2std_dict(obj[key]) def check_dir_existance(dirname): """ Helper function to check if a directory exists, creating it if not. :param dirname: [str] full path of the directory to check """ if not os.path.exists(dirname): os.mkdir(dirname) class DeepDict(object): """ The DeepDict class represents a nested dictionary with additional functionality compared to a standard lib dict. The data can be accessed and changed vie a pathlike access and dumped or read to .json/.xml files. Initializing instances using defaults creates an empty DeepDict. Using in_data enables to initialize the object instance with data, where in_data can be a dict, or a filepath to a json or xml file. Using path sep the appearance of path passing can be changed, a default data access via path would look like my_dd['target/section/path'] with path_sep='.' 
like so my_dd['target.section.path'] :param in_data: [dict] or [str], input dict or filename :param path_sep: [str] path separator character """ _data = None _sep = "/" def __init__(self, in_data=None, path_sep="/"): self.clear() self._sep = path_sep LOG.debug(f"path separator is: {self._sep}") if in_data is not None: if isinstance(in_data, str): self.from_file(in_data) elif isinstance(in_data, dict): self.data = in_data def __str__(self): """ Enables print output for class instances, printing the instance data dict using pretty print :return: [str] """ return pprint.pformat(self.data) def __eq__(self, other): """ Overloads the == operator comparing the instance data dictionaries for equality :param other: [DeepDict] rhs :return: [bool] """ return self.data == other.data def __getitem__(self, path): """ Overloads the return of the [] operator for data access. This enables access the DeepDict instance like so: my_dd['target/section/path'] or my_dd[['target','section','path']] :param path: [str] or [list(str)], the path to the target data structure level/content :return: [object] """ return DeepDict.get_from_path(self.data, path, self.sep) def __setitem__(self, path, value=None): """ Overloads the setter for the [] operator for data assignment. :param path: [str] or [list(str)], the path to the target data structure level/content :param value: [object] rhs assignment object """ if isinstance(path, str): path = path.split(self.sep) if not isinstance(path, list) or isinstance(path, tuple): raise IOError("Input Error, expect list[str] type for path") if len(path) < 1: raise IOError("Input Error, missing section strings") if not path[0] in self._data.keys(): if value is not None and len(path) == 1: self._data[path[0]] = value else: self._data[path[0]] = {} tmp = self._data[path[0]] path.pop(0) while True: if len(path) == 0: break if path[0] not in tmp.keys(): if value is not None and len(path) == 1: tmp[path[0]] = value else: tmp[path[0]] = {} tmp = tmp[path[0]] else: tmp = tmp[path[0]] path.pop(0) def __len__(self): return len(self._data) def clear(self): """ clears the instance data """ LOG.debug("clear()") self._data = {} def from_file(self, fname): """ Loads data from file. Currently implemented .json and .xml file reader. 
:param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.read_json(fname) elif fname.endswith(".xml"): self.read_xml(fname) else: LOG.error("Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def read_json(self, fname): """ Read json file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): raise IOError(f"File {fname} not found!") LOG.debug(f"read_json({fname})") try: with open(fname, "r") as read_file: self._data = json.load(read_file) DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: LOG.error(f"Error while reading json file {fname} or while converting types") raise IOError("Error while reading json file {fname} or while converting types") def read_xml(self, fname): """ Read xml file :param fname: [str] input filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if not os.path.isfile(fname): raise IOError(f"File {fname} not found!") LOG.debug(f"read_xml({fname})") try: with open(fname, "r") as read_file: xml = "".join(read_file.readlines()) self._data = xmltodict.parse(xml, attr_prefix='') DeepDict.value_traverse(self.data, callback=DeepDict.parse_type) except Exception as e: - LOG.error(f"Error while reading xml file {fname} or while converting types") - raise IOError("Error while reading json file {fname} or while converting types") + msg = f"Error while reading xml file {fname} or while converting types" + LOG.error(msg) + raise IOError(msg) # if written with DeepDict, the xml contains a root node called # deepdict which should beremoved for consistency reasons if DEEPDICT_XML_ROOT in self._data.keys(): self._data = self._data[DEEPDICT_XML_ROOT] self._data = dict(self.data) # convert the orderes dict structure to a default dict for consistency reasons convert_ordered2std_dict(self.data) def to_file(self, fname): """ Write to file, type is determined by checking the filename ending. Currently implemented is writing to json and to xml. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") if fname.endswith(".json"): self.write_json(fname) elif fname.endswith(".xml"): self.write_xml(fname) else: LOG.error(f"Unknown filetype, expect [.json, .xml]") raise NotImplementedError("Unknown filetype, expect [.json, .xml]") def write_json(self, fname): """ Dump data to json file. :param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) try: LOG.debug(f"write_json({fname})") with open(fname, "w") as write_file: json.dump(self.data, write_file) except Exception as e: LOG.error(f"Failed dumping to json file: {fname}") raise e def write_xml(self, fname): """ Dump data to json file. 
:param fname: [str] filename """ if not isinstance(fname, str): raise IOError("Input Error, expect str type for fname") check_dir_existance(os.path.dirname(fname)) xml = dicttoxml(self.data, custom_root=DEEPDICT_XML_ROOT, attr_type=False) LOG.debug(f"write_xml({fname})") try: with open(fname, "w") as write_file: write_file.write(xml.decode("utf-8")) except Exception as e: LOG.error(f"Failed dumping to xml file: {fname}") raise e def has_section(self, section): return DeepDict.has_key(self.data, section) @staticmethod def get_from_path(data, path, sep="/"): """ Implements a nested dict access via a path like string like so path='target/section/path' which is equivalent to my_dict['target']['section']['path']. :param data: [dict] input dictionary :param path: [str] pathlike string :param sep: [str] path separator, default='/' :return: [object] """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for data") raise IOError("Input Error, expect dict type for data") if isinstance(path, str): path = path.split(sep) if not isinstance(path, list) or isinstance(path, tuple): LOG.error(f"Input Error, expect list[str] type for path: {path}") raise IOError("Input Error, expect list[str] type for path") if not DeepDict.has_key(data, path[-1]): LOG.error(f"Input Error, section {path[-1]} does not exist in dictionary") raise IOError(f"Input Error, section {path[-1]} does not exist in dictionary") try: for k in path: data = data[k] except Exception as e: LOG.error(f"Failed retrieving data from path {path} due to {e}") raise LookupError(f"Failed retrieving data from path {path} due to {e}") return data @staticmethod def has_key(data, section, already_found=False): """ Checks if input dictionary has a key called section. The already_found parameter is for internal recursion checks. :param data: [dict] input dictionary :param section: [str] key string to search for :param already_found: recursion criteria check :return: [bool] section found """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(section, str): LOG.error(f"Input Error, expect dict type for obj {section}") raise IOError(f"Input Error, expect dict type for obj {section}") if already_found: return True found = False for key, value in data.items(): if key == section: found = True if isinstance(value, dict): found = DeepDict.has_key(data[key], section, found) return found @staticmethod def value_traverse(data, callback=None): """ Dictionary filter function, walks through the input dict (obj) calling the callback function for each value. The callback function return is assigned the the corresponding dict value. 
:param data: [dict] input dictionary :param callback: """ if not isinstance(data, dict): LOG.error("Input Error, expect dict type for obj") raise IOError("Input Error, expect dict type for obj") if not isinstance(callback, types.FunctionType): LOG.error("Input Error, expect function type for callback") raise IOError("Input Error, expect function type for callback") for key, value in data.items(): if isinstance(value, dict): DeepDict.value_traverse(data[key], callback) else: data[key] = callback(value) @staticmethod def parse_type(string): """ Type convert input string to float, int, list, tuple or string :param string: [str] input string :return: [T] converted output """ try: a = float(string) try: b = int(string) except ValueError: return float(string) if a == b: return b return a except ValueError: if string.startswith("[") and string.endswith("]"): string = re.sub(' ', '', string) elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return li elif string.startswith("(") and string.endswith(")"): elements = string[1:-1].split(",") li = [] for e in elements: li.append(DeepDict.parse_type(e)) return tuple(li) return string @property def data(self): return self._data @data.setter def data(self, value): if not isinstance(value, dict): LOG.error(f"Input Error, expect dict type for value, but got {type(value)}") raise IOError(f"Input Error, expect dict type for value, but got {type(value)}") self.clear() self._data = value @property def sep(self): return self._sep @sep.setter def sep(self, value): if not isinstance(value, str): LOG.error(f"Input Error, expect str type for value, but got {type(value)}") raise IOError(f"Input Error, expect str type for value, but got {type(value)}") self._sep = value diff --git a/hyppopy/globals.py b/hyppopy/globals.py index a694c1f..fd6f33f 100644 --- a/hyppopy/globals.py +++ b/hyppopy/globals.py @@ -1,31 +1,32 @@ # DKFZ # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # -*- coding: utf-8 -*- import os import sys import logging ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT) PLUGIN_DEFAULT_DIR = os.path.join(ROOT, *("hyppopy", "plugins")) TESTDATA_DIR = os.path.join(ROOT, *("hyppopy", "tests", "data")) SETTINGSPATH = "settings/solver" +CUSTOMPATH = "settings/custom" DEEPDICT_XML_ROOT = "hyppopy" LOGFILENAME = os.path.join(ROOT, 'logfile.log') DEBUGLEVEL = logging.DEBUG logging.basicConfig(filename=LOGFILENAME, filemode='w', format='%(levelname)s: %(name)s - %(message)s') diff --git a/hyppopy/plugins/hyperopt_solver_plugin.py b/hyppopy/plugins/hyperopt_solver_plugin.py index 6d82017..b21bc2e 100644 --- a/hyppopy/plugins/hyperopt_solver_plugin.py +++ b/hyppopy/plugins/hyperopt_solver_plugin.py @@ -1,65 +1,67 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
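A short usage sketch of the DeepDict path-style access and automatic type parsing defined above; the values are illustrative only:

    from hyppopy.deepdict import DeepDict

    dd = DeepDict({"settings": {"solver": {"max_iterations": 50}}})
    dd["hyperparameter/C/domain"] = "uniform"        # nested sections are created on the fly
    dd["hyperparameter/C/data"] = [0, 20]

    print(dd["settings/solver/max_iterations"])      # 50
    print(dd.has_section("solver"))                  # True
    print(DeepDict.parse_type("[1, 2.5, 3]"))        # [1, 2.5, 3], parsed to int/float values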
# # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat from hyperopt import fmin, tpe, hp, STATUS_OK, STATUS_FAIL, Trials from yapsy.IPlugin import IPlugin from hyppopy.solverpluginbase import SolverPluginBase class hyperopt_Solver(SolverPluginBase, IPlugin): trials = None best = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") def loss_function(self, params): try: loss = self.loss(self.data, params) status = STATUS_OK except Exception as e: + LOG.error(f"execution of self.loss(self.data, params) failed due to:\n {e}") status = STATUS_FAIL return {'loss': loss, 'status': status} def execute_solver(self, parameter): LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n") self.trials = Trials() try: self.best = fmin(fn=self.loss_function, space=parameter, algo=tpe.suggest, - max_evals=self.max_iterations, + max_evals=self.settings.max_iterations, trials=self.trials) except Exception as e: - LOG.error(f"internal error in hyperopt.fmin occured. {e}") - raise BrokenPipeError(f"internal error in hyperopt.fmin occured. {e}") + msg = f"internal error in hyperopt.fmin occured. {e}" + LOG.error(msg) + raise BrokenPipeError(msg) def convert_results(self): solution = dict([(k, v) for k, v in self.best.items() if v is not None]) print('Solution\n========') print("\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items()))) diff --git a/hyppopy/plugins/optunity_solver_plugin.py b/hyppopy/plugins/optunity_solver_plugin.py index 7d1c6d0..53c0351 100644 --- a/hyppopy/plugins/optunity_solver_plugin.py +++ b/hyppopy/plugins/optunity_solver_plugin.py @@ -1,68 +1,68 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) from pprint import pformat import optunity from yapsy.IPlugin import IPlugin from hyppopy.solverpluginbase import SolverPluginBase class optunity_Solver(SolverPluginBase, IPlugin): solver_info = None trials = None best = None status = None def __init__(self): SolverPluginBase.__init__(self) LOG.debug("initialized") def loss_function(self, **params): try: loss = self.loss(self.data, params) self.status.append('ok') return loss except Exception as e: self.status.append('fail') return 1e9 def execute_solver(self, parameter): LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n") self.status = [] try: self.best, self.trials, self.solver_info = optunity.minimize_structured(f=self.loss_function, - num_evals=self.max_iterations, + num_evals=self.settings.max_iterations, search_space=parameter) except Exception as e: LOG.error(f"internal error in optunity.minimize_structured occured. {e}") raise BrokenPipeError(f"internal error in optunity.minimize_structured occured. 
{e}") def convert_results(self): solution = dict([(k, v) for k, v in self.best.items() if v is not None]) print('Solution\n========') print("\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items()))) print(f"Solver used: {self.solver_info['solver_name']}") print(f"Optimum: {self.trials.optimum}") print(f"Iterations used: {self.trials.stats['num_evals']}") print(f"Duration: {self.trials.stats['time']} s") diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py index bff67e6..31ee767 100644 --- a/hyppopy/settingspluginbase.py +++ b/hyppopy/settingspluginbase.py @@ -1,84 +1,87 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) -from hyppopy.globals import SETTINGSPATH +from hyppopy.globals import SETTINGSPATH, CUSTOMPATH from hyppopy.deepdict import DeepDict class SettingsPluginBase(object): _data = None _name = None def __init__(self): self._data = DeepDict() @abc.abstractmethod def convert_parameter(self): raise NotImplementedError('users must define convert_parameter to use this base class') def get_hyperparameter(self): return self.convert_parameter(self.data["hyperparameter"]) def set(self, data): self.data.clear() self.data = data def read(self, fname): self.data.clear() self.data.from_file(fname) def write(self, fname): self.data.to_file(fname) def set_attributes(self, cls): attrs_sec = self.data[SETTINGSPATH] for key, value in attrs_sec.items(): setattr(cls, key, value) + attrs_sec = self.data[CUSTOMPATH] + for key, value in attrs_sec.items(): + setattr(cls, key, value) @property def data(self): return self._data @data.setter def data(self, value): if isinstance(value, dict): self._data.data = value elif isinstance(value, DeepDict): self._data = value else: raise IOError(f"unexpected input type({type(value)}) for data, needs to be of type dict or DeepDict!") @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solver.py b/hyppopy/solver.py index 8a93e5e..e2a9a5e 100644 --- a/hyppopy/solver.py +++ b/hyppopy/solver.py @@ -1,83 +1,84 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
# # Author: Sven Wanner (s.wanner@dkfz.de) import os import logging from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class Solver(object): _name = None _solver_plugin = None _settings_plugin = None def __init__(self): pass def set_data(self, data): self.solver.set_data(data) def set_parameters(self, params): self.settings.set(params) self.settings.set_attributes(self.solver) def read_parameter(self, fname): self.settings.read(fname) - self.settings.set_attributes(self.solver) + self.settings.set_attributes(self.settings) def set_loss_function(self, loss_func): self.solver.set_loss_function(loss_func) def run(self): - self.solver.run(self.settings.get_hyperparameter()) + self.solver.settings = self.settings + self.solver.run() def get_results(self): self.solver.get_results() @property def is_ready(self): return self.solver is not None and self.settings is not None @property def solver(self): return self._solver_plugin @solver.setter def solver(self, value): self._solver_plugin = value @property def settings(self): return self._settings_plugin @settings.setter def settings(self, value): self._settings_plugin = value @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") self._name = value diff --git a/hyppopy/solverpluginbase.py b/hyppopy/solverpluginbase.py index aa14611..91ce6ba 100644 --- a/hyppopy/solverpluginbase.py +++ b/hyppopy/solverpluginbase.py @@ -1,69 +1,86 @@ # -*- coding: utf-8 -*- # # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. 
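A condensed sketch of driving the reworked Solver facade above, mirroring the example removed from cmdtools.py; the config file name is a placeholder and the plugins must be discoverable via the SolverFactory:

    from sklearn import datasets
    from sklearn.svm import SVC
    from sklearn.model_selection import cross_val_score

    import hyppopy.solverfactory as sfac

    factory = sfac.SolverFactory.instance()
    solver = factory.get_solver("hyperopt")          # or "optunity"
    solver.read_parameter("config.json")             # placeholder config file

    iris = datasets.load_iris()
    solver.set_data([iris.data, iris.target])

    def svc_loss(data, params):
        clf = SVC(**params)
        return -cross_val_score(clf, data[0], data[1], cv=3).mean()

    solver.set_loss_function(svc_loss)
    solver.run()                                     # run() now hands the settings plugin to the solver plugin
    solver.get_results()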
# # Author: Sven Wanner (s.wanner@dkfz.de) import abc import os import logging from hyppopy.globals import DEBUGLEVEL +from hyppopy.settingspluginbase import SettingsPluginBase + LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class SolverPluginBase(object): data = None loss = None + _settings = None _name = None def __init__(self): pass @abc.abstractmethod def loss_function(self, params): raise NotImplementedError('users must define loss_func to use this base class') @abc.abstractmethod def execute_solver(self): raise NotImplementedError('users must define execute_solver to use this base class') @abc.abstractmethod def convert_results(self): raise NotImplementedError('users must define convert_results to use this base class') def set_data(self, data): self.data = data def set_loss_function(self, func): self.loss = func def get_results(self): self.convert_results() - def run(self, parameter): - self.execute_solver(parameter) + def run(self): + self.execute_solver(self.settings.get_hyperparameter()) @property def name(self): return self._name @name.setter def name(self, value): if not isinstance(value, str): - LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead") - raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead") + msg = f"Invalid input, str type expected for value, got {type(value)} instead" + LOG.error(msg) + raise IOError(msg) self._name = value + @property + def settings(self): + return self._settings + + @settings.setter + def settings(self, value): + if not isinstance(value, SettingsPluginBase): + msg = f"Invalid input, SettingsPluginBase type expected for value, got {type(value)} instead" + LOG.error(msg) + raise IOError(msg) + self._settings = value + + diff --git a/hyppopy/workflows/__init__.py b/hyppopy/workflows/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hyppopy/workflows/randomforest_usecase/__init__.py b/hyppopy/workflows/randomforest_usecase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py new file mode 100644 index 0000000..1fb55db --- /dev/null +++ b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. 
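To illustrate the contract SolverPluginBase now defines (Solver.run() injects the settings plugin, and the base run() calls execute_solver() with the converted hyperparameter space), here is a minimal random-search plugin sketch. It is not part of the diff; it assumes a paired settings plugin whose convert_parameter() passes the domain/data dictionaries through unchanged, a max_iterations entry in settings/solver, and omits the yapsy plugin registration files:

    import random

    from yapsy.IPlugin import IPlugin
    from hyppopy.solverpluginbase import SolverPluginBase


    class randomsearch_Solver(SolverPluginBase, IPlugin):
        best = None

        def __init__(self):
            SolverPluginBase.__init__(self)

        def loss_function(self, params):
            return self.loss(self.data, params)

        def execute_solver(self, parameter):
            # draw max_iterations random samples from uniform/categorical domains
            best_loss = float("inf")
            for _ in range(self.settings.max_iterations):
                params = {}
                for name, spec in parameter.items():
                    if spec["domain"] == "categorical":
                        params[name] = random.choice(spec["data"])
                    else:
                        params[name] = random.uniform(spec["data"][0], spec["data"][1])
                loss = self.loss_function(params)
                if loss < best_loss:
                    best_loss, self.best = loss, params

        def convert_results(self):
            print("Solution\n========")
            print(self.best)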
+# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import numpy as np +import pandas as pd +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import cross_val_score + +import hyppopy.solverfactory as sfac + + +def data_loader(path, data_name, labels_name): + if data_name.endswith(".npy"): + if not labels_name.endswith(".npy"): + raise IOError("Expect both data_name and labels_name being of type .npy!") + data = [np.load(os.path.join(path, data_name)), np.load(os.path.join(path, labels_name))] + elif data_name.endswith(".csv"): + try: + dataset = pd.read_csv(os.path.join(path, data_name)) + y = dataset[labels_name].values + X = dataset.drop([labels_name], axis=1).values + data = [X, y] + except Exception as e: + print("Precondition violation, this usage case expects as data_name a " + "csv file and as label_name a name of a column in this csv table!") + else: + raise NotImplementedError("This combination of data_name and labels_name " + "does not yet exist, feel free to add it") + return data + + +def randomforest_usecase(args): + print("Execute Random Forest UseCase...") + + factory = sfac.SolverFactory.instance() + solver = factory.get_solver(args.plugin) + solver.read_parameter(args.config) + + data = data_loader(args.data, solver.settings.data_name, solver.settings.labels_name) + solver.set_data(data) + + def rf_loss(data, params): + clf = RandomForestClassifier(**params) + return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() + + solver.set_loss_function(rf_loss) + solver.run() + solver.get_results() diff --git a/hyppopy/workflows/svc_usecase/__init__.py b/hyppopy/workflows/svc_usecase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hyppopy/workflows/svc_usecase/svc_usecase.py b/hyppopy/workflows/svc_usecase/svc_usecase.py new file mode 100644 index 0000000..45ab10c --- /dev/null +++ b/hyppopy/workflows/svc_usecase/svc_usecase.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. 
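The data_loader above (shared by the SVC use case that follows) accepts either a pair of .npy files or a .csv table with a label column. A short sketch of producing compatible inputs from the iris data; the folder and file names are placeholders that would be referenced via --data and the config's data_name/labels_name:

    import os
    import numpy as np
    import pandas as pd
    from sklearn import datasets

    iris = datasets.load_iris()
    data_dir = "iris_data"                                   # placeholder, passed via --data
    os.makedirs(data_dir, exist_ok=True)

    # variant 1: a .npy pair -> data_name="train_data.npy", labels_name="train_labels.npy"
    np.save(os.path.join(data_dir, "train_data.npy"), iris.data)
    np.save(os.path.join(data_dir, "train_labels.npy"), iris.target)

    # variant 2: a .csv table -> data_name="iris.csv", labels_name="target"
    df = pd.DataFrame(iris.data, columns=["f0", "f1", "f2", "f3"])
    df["target"] = iris.target
    df.to_csv(os.path.join(data_dir, "iris.csv"), index=False)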
+# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import numpy as np +import pandas as pd +from sklearn.svm import SVC +from sklearn.model_selection import cross_val_score + +import hyppopy.solverfactory as sfac + + +def data_loader(path, data_name, labels_name): + if data_name.endswith(".npy"): + if not labels_name.endswith(".npy"): + raise IOError("Expect both data_name and labels_name being of type .npy!") + data = [np.load(os.path.join(path, data_name)), np.load(os.path.join(path, labels_name))] + elif data_name.endswith(".csv"): + try: + dataset = pd.read_csv(os.path.join(path, data_name)) + y = dataset[labels_name].values + X = dataset.drop([labels_name], axis=1).values + data = [X, y] + except Exception as e: + print("Precondition violation, this usage case expects as data_name a " + "csv file and as label_name a name of a column in this csv table!") + else: + raise NotImplementedError("This combination of data_name and labels_name " + "does not yet exist, feel free to add it") + return data + + +def svc_usecase(args): + print("Execute SVC UseCase...") + + factory = sfac.SolverFactory.instance() + solver = factory.get_solver(args.plugin) + solver.read_parameter(args.config) + + data = data_loader(args.data, solver.settings.data_name, solver.settings.labels_name) + solver.set_data(data) + + def svc_loss(data, params): + clf = SVC(**params) + return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() + + solver.set_loss_function(svc_loss) + solver.run() + solver.get_results() diff --git a/hyppopy/workflows/unet_usecase/__init__.py b/hyppopy/workflows/unet_usecase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hyppopy/workflows/unet_usecase/unet_uscase_utils.py b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py new file mode 100644 index 0000000..ae1f4ce --- /dev/null +++ b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py @@ -0,0 +1,419 @@ +# -*- coding: utf-8 -*- +# +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. 
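The SVC use case above reads its search space from the config file instead of hard-coding it as the removed cmdtools.py did. For reference, the equivalent hyperparameter section, taken from the deleted example:

    svc_hyperparameter = {
        "C":      {"domain": "uniform",     "data": [0, 20]},
        "gamma":  {"domain": "uniform",     "data": [0.0001, 20.0]},
        "kernel": {"domain": "categorical", "data": ["linear", "sigmoid", "poly", "rbf"]},
    }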
+#
+# Author: Sven Wanner (s.wanner@dkfz.de)
+
+import os
+import torch
+import pickle
+import random
+import fnmatch
+import numpy as np
+from torch import nn
+from medpy.io import load
+from collections import defaultdict
+from abc import ABCMeta, abstractmethod
+
+
+def sum_tensor(input, axes, keepdim=False):
+    axes = np.unique(axes).astype(int)
+    if keepdim:
+        for ax in axes:
+            input = input.sum(int(ax), keepdim=True)
+    else:
+        for ax in sorted(axes, reverse=True):
+            input = input.sum(int(ax))
+    return input
+
+
+def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None):
+    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
+        rebalance_weights = rebalance_weights[1:]  # this is the case when use_bg=False
+    axes = tuple([0] + list(range(2, len(net_output.size()))))
+    tp = sum_tensor(net_output * gt, axes, keepdim=False)
+    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
+    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
+    weights = torch.ones(tp.shape)
+    weights[0] = background_weight
+    if net_output.device.type == "cuda":
+        weights = weights.cuda(net_output.device.index)
+    if rebalance_weights is not None:
+        rebalance_weights = torch.from_numpy(rebalance_weights).float()
+        if net_output.device.type == "cuda":
+            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
+        tp = tp * rebalance_weights
+        fn = fn * rebalance_weights
+    result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean()
+    return result
+
+
+def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
+    axes = tuple(range(2, len(net_output.size())))
+    intersect = sum_tensor(net_output * gt, axes, keepdim=False)
+    denom = sum_tensor(net_output + gt, axes, keepdim=False)
+    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()  # unweighted soft dice; class weighting is handled in soft_dice_per_batch_2
+ return result + + +class SoftDiceLoss(nn.Module): + def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True, background_weight=1, rebalance_weights=None): + """ + hahaa no documentation for you today + :param smooth: + :param apply_nonlin: + :param batch_dice: + :param do_bg: + :param smooth_in_nom: + :param background_weight: + :param rebalance_weights: + """ + super(SoftDiceLoss, self).__init__() + if not do_bg: + assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy" + self.rebalance_weights = rebalance_weights + self.background_weight = background_weight + self.smooth_in_nom = smooth_in_nom + self.do_bg = do_bg + self.batch_dice = batch_dice + self.apply_nonlin = apply_nonlin + self.smooth = smooth + self.y_onehot = None + if not smooth_in_nom: + self.nom_smooth = 0 + else: + self.nom_smooth = smooth + + def forward(self, x, y): + with torch.no_grad(): + y = y.long() + shp_x = x.shape + shp_y = y.shape + if self.apply_nonlin is not None: + x = self.apply_nonlin(x) + if len(shp_x) != len(shp_y): + y = y.view((shp_y[0], 1, *shp_y[1:])) + # now x and y should have shape (B, C, X, Y(, Z))) and (B, 1, X, Y(, Z))), respectively + y_onehot = torch.zeros(shp_x) + if x.device.type == "cuda": + y_onehot = y_onehot.cuda(x.device.index) + y_onehot.scatter_(1, y, 1) + if not self.do_bg: + x = x[:, 1:] + y_onehot = y_onehot[:, 1:] + if not self.batch_dice: + if self.background_weight != 1 or (self.rebalance_weights is not None): + raise NotImplementedError("nah son") + l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom) + else: + l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom, + background_weight=self.background_weight, + rebalance_weights=self.rebalance_weights) + return l + + +def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None): + fls = [] + files_len = [] + slices_ax = [] + + for root, dirs, files in os.walk(base_dir): + i = 0 + for filename in sorted(fnmatch.filter(files, pattern)): + + if keys is not None and filename[:-4] in keys: + npy_file = os.path.join(root, filename) + numpy_array = np.load(npy_file, mmap_mode="r") + + fls.append(npy_file) + files_len.append(numpy_array.shape[1]) + + slices_ax.extend([(i, j) for j in range(slice_offset, files_len[-1] - slice_offset)]) + + i += 1 + + return fls, files_len, slices_ax, + + +class SlimDataLoaderBase(object): + def __init__(self, data, batch_size, number_of_threads_in_multithreaded=None): + """ + Slim version of DataLoaderBase (which is now deprecated). Only provides very simple functionality. + You must derive from this class to implement your own DataLoader. You must overrive self.generate_train_batch() + If you use our MultiThreadedAugmenter you will need to also set and use number_of_threads_in_multithreaded. See + multithreaded_dataloading in examples! + :param data: will be stored in self._data. You can use it to generate your batches in self.generate_train_batch() + :param batch_size: will be stored in self.batch_size for use in self.generate_train_batch() + :param number_of_threads_in_multithreaded: will be stored in self.number_of_threads_in_multithreaded. + None per default. 
If you wish to iterate over all your training data only once per epoch, you must coordinate + your Dataloaders and you will need this information + """ + __metaclass__ = ABCMeta + self.number_of_threads_in_multithreaded = number_of_threads_in_multithreaded + self._data = data + self.batch_size = batch_size + self.thread_id = 0 + + def set_thread_id(self, thread_id): + self.thread_id = thread_id + + def __iter__(self): + return self + + def __next__(self): + return self.generate_train_batch() + + @abstractmethod + def generate_train_batch(self): + '''override this + Generate your batch from self._data .Make sure you generate the correct batch size (self.BATCH_SIZE) + ''' + pass + + +class NumpyDataLoader(SlimDataLoaderBase): + def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, + seed=None, file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None): + + self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern, slice_offset=0, keys=keys, ) + super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches) + + self.batch_size = batch_size + + self.use_next = False + if mode == "train": + self.use_next = False + + self.slice_idxs = list(range(0, len(self.slices))) + + self.data_len = len(self.slices) + + self.num_batches = min((self.data_len // self.batch_size)+10, num_batches) + + if isinstance(label_slice, int): + label_slice = (label_slice,) + self.input_slice = input_slice + self.label_slice = label_slice + + self.np_data = np.asarray(self.slices) + + def reshuffle(self): + print("Reshuffle...") + random.shuffle(self.slice_idxs) + print("Initializing... this might take a while...") + + def generate_train_batch(self): + open_arr = random.sample(self._data, self.batch_size) + return self.get_data_from_array(open_arr) + + def __len__(self): + n_items = min(self.data_len // self.batch_size, self.num_batches) + return n_items + + def __getitem__(self, item): + slice_idxs = self.slice_idxs + data_len = len(self.slices) + np_data = self.np_data + + if item > len(self): + raise StopIteration() + if (item * self.batch_size) == data_len: + raise StopIteration() + + start_idx = (item * self.batch_size) % data_len + stop_idx = ((item + 1) * self.batch_size) % data_len + + if ((item + 1) * self.batch_size) == data_len: + stop_idx = data_len + + if stop_idx > start_idx: + idxs = slice_idxs[start_idx:stop_idx] + else: + raise StopIteration() + + open_arr = np_data[idxs] + + return self.get_data_from_array(open_arr) + + def get_data_from_array(self, open_array): + data = [] + fnames = [] + slice_idxs = [] + labels = [] + + for slice in open_array: + fn_name = self.files[slice[0]] + + numpy_array = np.load(fn_name, mmap_mode="r") + + numpy_slice = numpy_array[ :, slice[1], ] + data.append(numpy_slice[None, self.input_slice[0]]) # 'None' keeps the dimension + + if self.label_slice is not None: + labels.append(numpy_slice[None, self.label_slice[0]]) # 'None' keeps the dimension + + fnames.append(self.files[slice[0]]) + slice_idxs.append(slice[1]) + + ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs} + if self.label_slice is not None: + ret_dict['seg'] = np.asarray(labels) + + return ret_dict + + +class NumpyDataSet(object): + """ + TODO + """ + def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128, + file_pattern='*.npy', label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None): + + 
data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, seed=seed, file_pattern=file_pattern, + input_slice=input_slice, label_slice=label_slice, keys=keys) + + self.data_loader = data_loader + self.batch_size = batch_size + self.do_reshuffle = do_reshuffle + self.number_of_slices = 1 + + self.transforms = get_transforms(mode=mode, target_size=target_size) + self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes, + num_cached_per_queue=num_cached_per_queue, seeds=seed, + shuffle=do_reshuffle) + self.augmenter.restart() + + def __len__(self): + return len(self.data_loader) + + def __iter__(self): + if self.do_reshuffle: + self.data_loader.reshuffle() + self.augmenter.renew() + return self.augmenter + + def __next__(self): + return next(self.augmenter) + + + +def reshape(orig_img, append_value=-1024, new_shape=(512, 512, 512)): + reshaped_image = np.zeros(new_shape) + reshaped_image[...] = append_value + x_offset = 0 + y_offset = 0 # (new_shape[1] - orig_img.shape[1]) // 2 + z_offset = 0 # (new_shape[2] - orig_img.shape[2]) // 2 + + reshaped_image[x_offset:orig_img.shape[0] + x_offset, y_offset:orig_img.shape[1] + y_offset, + z_offset:orig_img.shape[2] + z_offset] = orig_img + # insert temp_img.min() as background value + + return reshaped_image + + +def subfiles(folder, join=True, prefix=None, suffix=None, sort=True): + if join: + l = os.path.join + else: + l = lambda x, y: y + res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix))] + if sort: + res.sort() + return res + + +def preprocess_data(root_dir): + print("preprocess data...") + image_dir = os.path.join(root_dir, 'imagesTr') + print(f"image_dir: {image_dir}") + label_dir = os.path.join(root_dir, 'labelsTr') + print(f"label_dir: {label_dir}") + output_dir = os.path.join(root_dir, 'preprocessed') + print(f"output_dir: {output_dir} ... ", end="") + classes = 3 + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("created!") + else: + print("found!\npreprocessed data already available, aborted preprocessing!") + return False + + print("start preprocessing ... ", end="") + class_stats = defaultdict(int) + total = 0 + + nii_files = subfiles(image_dir, suffix=".nii.gz", join=False) + + for i in range(0, len(nii_files)): + if nii_files[i].startswith("._"): + nii_files[i] = nii_files[i][2:] + + for i, f in enumerate(nii_files): + image, _ = load(os.path.join(image_dir, f)) + label, _ = load(os.path.join(label_dir, f.replace('_0000', ''))) + + for i in range(classes): + class_stats[i] += np.sum(label == i) + total += np.sum(label == i) + + image = (image - image.min()) / (image.max() - image.min()) + + image = reshape(image, append_value=0, new_shape=(64, 64, 64)) + label = reshape(label, append_value=0, new_shape=(64, 64, 64)) + + result = np.stack((image, label)) + + np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result) + print("finished!") + return True + + +def create_splits(output_dir, image_dir): + print("creating splits ... 
", end="") + npy_files = subfiles(image_dir, suffix=".npy", join=False) + + trainset_size = len(npy_files) * 50 // 100 + valset_size = len(npy_files) * 25 // 100 + testset_size = len(npy_files) * 25 // 100 + + splits = [] + for split in range(0, 5): + image_list = npy_files.copy() + trainset = [] + valset = [] + testset = [] + for i in range(0, trainset_size): + patient = np.random.choice(image_list) + image_list.remove(patient) + trainset.append(patient[:-4]) + for i in range(0, valset_size): + patient = np.random.choice(image_list) + image_list.remove(patient) + valset.append(patient[:-4]) + for i in range(0, testset_size): + patient = np.random.choice(image_list) + image_list.remove(patient) + testset.append(patient[:-4]) + split_dict = dict() + split_dict['train'] = trainset + split_dict['val'] = valset + split_dict['test'] = testset + + splits.append(split_dict) + + with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f: + pickle.dump(splits, f) + print("finished!") diff --git a/hyppopy/workflows/unet_usecase/unet_usecase.py b/hyppopy/workflows/unet_usecase/unet_usecase.py new file mode 100644 index 0000000..550400d --- /dev/null +++ b/hyppopy/workflows/unet_usecase/unet_usecase.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# DKFZ +# +# +# Copyright (c) German Cancer Research Center, +# Division of Medical and Biological Informatics. +# All rights reserved. +# +# This software is distributed WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. +# +# See LICENSE.txt or http://www.mitk.org for details. +# +# Author: Sven Wanner (s.wanner@dkfz.de) + +import os +import pickle + +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import ReduceLROnPlateau +import torch.nn.functional as F +from networks.RecursiveUNet import UNet + +import hyppopy.solverfactory as sfac +from .unet_uscase_utils import * + + +def unet_usecase(args): + print("Execute UNet UseCase...") + data_dir = args.data + preprocessed_dir = os.path.join(args.data, 'preprocessed') + solver_plugin = args.plugin + config_file = args.config + print(f"input data directory: {data_dir}") + print(f"use plugin: {solver_plugin}") + print(f"config file: {config_file}") + + factory = sfac.SolverFactory.instance() + solver = factory.get_solver(solver_plugin) + solver.read_parameter(config_file) + + if preprocess_data(data_dir): + create_splits(output_dir=data_dir, image_dir=preprocessed_dir) + + with open(os.path.join(data_dir, "splits.pkl"), 'rb') as f: + splits = pickle.load(f) + + tr_keys = splits[solver.settings.fold]['train'] + val_keys = splits[solver.settings.fold]['val'] + test_keys = splits[solver.settings.fold]['test'] + + def loss_function(patch_size, batch_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_data_loader = NumpyDataSet(data_dir, + target_size=patch_size, + batch_size=batch_size, + keys=tr_keys) + val_data_loader = NumpyDataSet(data_dir, + target_size=patch_size, + batch_size=batch_size, + mode="val", + do_reshuffle=False) + model = UNet(num_classes=solver.settings.num_classes, in_channels=solver.settings.in_channels) + model.to(device) + + dice_loss = SoftDiceLoss(batch_dice=True) + ce_loss = torch.nn.CrossEntropyLoss() + node_optimizer = optim.Adam(model.parameters(), lr=solver.settings.learning_rate) + scheduler = ReduceLROnPlateau(node_optimizer, 'min') + + model.train() + + data = None + batch_counter = 0 + for data_batch in train_data_loader: + + node_optimizer.zero_grad() + + 
data = data_batch['data'][0].float().to(device) + target = data_batch['seg'][0].long().to(device) + + pred = model(data) + pred_softmax = F.softmax(pred, dim=1) + + loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze()) + loss.backward() + node_optimizer.step() + batch_counter += 1 + + assert data is not None, 'data is None. Please check if your dataloader works properly' + + model.eval() + + data = None + loss_list = [] + + with torch.no_grad(): + for data_batch in val_data_loader: + data = data_batch['data'][0].float().to(device) + target = data_batch['seg'][0].long().to(device) + + pred = model(data) + pred_softmax = F.softmax(pred) + + loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze()) + loss_list.append(loss.item()) + + assert data is not None, 'data is None. Please check if your dataloader works properly' + scheduler.step(np.mean(loss_list)) + + data = [] + + # solver.set_data(data) + # solver.read_parameter(config_file) + # solver.set_loss_function(loss_function) + # solver.run() + # solver.get_results() + + + + + + diff --git a/setup.py b/setup.py index 40ba7c1..9221d46 100644 --- a/setup.py +++ b/setup.py @@ -1,37 +1,39 @@ # -*- coding: utf-8 -*- from setuptools import setup, find_packages with open('README.rst') as f: readme = f.read() with open('LICENSE') as f: license = f.read() setup( name='hyppopy', version='0.0.1', description='Hyper-Parameter Optimization Toolbox for Blackboxfunction Optimization', long_description=readme, # if you want, put your own name here # (this would likely result in people sending you emails) author='Sven Wanner', author_email='s.wanner@dkfz.de', url='', license=license, packages=find_packages(exclude=('bin', '*test*', 'doc', 'hyppopy')), # the requirements to install this project. # Since this one is so simple this is empty. - install_requires=[], + install_requires=['dicttoxml>=1.7.4', 'hyperopt>=0.1.1', 'matplotlib>=3.0.2', 'numpy>=1.16.0', + 'Optunity>=1.1.1', 'pytest>=4.1.1', 'scikit-learn>=0.20.2', 'scipy>=1.2.0', 'sklearn>=0.0', 'Sphinx>=1.8.3', + 'xmlrunner>=1.7.7', 'xmltodict>=0.11.0', 'Yapsy>=1.11.223', 'visdom>=0.1.8.8'], # a more sophisticated project might have something like: #install_requires=['numpy>=1.11.0', 'scipy>=0.17', 'scikit-learn'] # after running setup.py, you will be able to call hypopy_exe # from the console as if it was a normal binary. It will call the function # main in bin/hypopy_exe.py entry_points={ 'console_scripts': ['hyppopy_exe=bin.hypopy_exe:main'], } )
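The console_scripts entry in setup.py resolves to an importable callable at install time. A minimal self-contained sketch of such a main() wrapper for the argument parsing used by bin/hyppopy_exe.py; this exact function is an assumption and is not part of the diff:

    import argparse

    def main():
        parser = argparse.ArgumentParser(description="Hyppopy use case optimization")
        parser.add_argument("-w", "--workflow", type=str, help="workflow to be executed")
        parser.add_argument("-p", "--plugin", type=str, default="hyperopt", help="solver plugin to use")
        parser.add_argument("-d", "--data", type=str, help="training data path")
        parser.add_argument("-c", "--config", type=str, help="config file name (.json or .xml)")
        parser.add_argument("-i", "--iterations", type=int, default=0, help="iteration override, 0 keeps the config value")
        args = parser.parse_args()
        print(args)  # dispatching to the workflow functions would happen here, as in bin/hyppopy_exe.py

    if __name__ == "__main__":
        main()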