diff --git a/hyppopy/deepdict.py b/hyppopy/deepdict.py
index 9a31c8a..a149797 100644
--- a/hyppopy/deepdict.py
+++ b/hyppopy/deepdict.py
@@ -1,386 +1,435 @@
# -*- coding: utf-8 -*-
#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import re
import json
import types
import pprint
import xmltodict
from dicttoxml import dicttoxml
from collections import OrderedDict
import logging
LOG = logging.getLogger('hyppopy')
from hyppopy.globals import DEEPDICT_XML_ROOT
def convert_ordered2std_dict(obj):
    """
    Helper function converting an OrderedDict into a standard lib dict.

    Works recursively: every nested OrderedDict value is replaced in place
    by an equivalent plain dict.

    :param obj: [OrderedDict] dictionary to convert in place
    """
    for key in obj:
        child = obj[key]
        if isinstance(child, OrderedDict):
            plain = dict(child)
            obj[key] = plain
            convert_ordered2std_dict(plain)
def check_dir_existance(dirname):
    """
    Helper function ensuring a directory exists, creating it (including any
    missing parent directories) if necessary.

    :param dirname: [str] full path of the directory to check; an empty
        string (e.g. os.path.dirname of a bare filename) is a no-op
    """
    # os.makedirs with exist_ok=True also handles nested paths and avoids a
    # race between the existence check and creation; the original os.mkdir
    # failed when intermediate directories were missing and raised on "".
    if dirname:
        os.makedirs(dirname, exist_ok=True)
class DeepDict(object):
"""
The DeepDict class represents a nested dictionary with additional functionality compared to a standard
lib dict. The data can be accessed and changed via a path-like access and dumped to or read from .json/.xml files.
Initializing instances using defaults creates an empty DeepDict. Using in_data enables to initialize the
object instance with data, where in_data can be a dict, or a filepath to a json or xml file. Using path sep
the separator used in access paths can be changed; the default access my_dd['target/section/path'] becomes my_dd['target.section.path'] with path_sep='.'.
:param in_data: [dict] or [str], input dict or filename
:param path_sep: [str] path separator character
"""
_data = None
_sep = "/"
def __init__(self, in_data=None, path_sep="/"):
    """
    Create a DeepDict, optionally initialized from a dict or a file.

    :param in_data: [dict] or [str], input dict or .json/.xml filename
    :param path_sep: [str] separator character used in bracket-access paths
    """
    self.clear()
    self._sep = path_sep
    LOG.debug(f"path separator is: {self._sep}")
    if in_data is not None:
        # a string is interpreted as a filename, a dict as direct data
        if isinstance(in_data, str):
            self.from_file(in_data)
        elif isinstance(in_data, dict):
            self.data = in_data
def __str__(self):
"""
Enables print output for class instances, printing the instance data dict using pretty print
:return: [str]
"""
return pprint.pformat(self.data)
def __eq__(self, other):
"""
Overloads the == operator comparing the instance data dictionaries for equality
:param other: [DeepDict] rhs
:return: [bool]
"""
return self.data == other.data
def __getitem__(self, path):
    """
    Overloads the return of the [] operator for data access. This enables accessing
    the DeepDict instance like so:
    my_dd['target/section/path'] or my_dd[['target','section','path']]

    :param path: [str] or [list(str)], the path to the target data structure level/content
    :return: [object]
    """
    # delegate to the static helper using this instance's separator
    return DeepDict.get_from_path(self.data, path, self.sep)
def __setitem__(self, path, value=None):
"""
Overloads the setter for the [] operator for data assignment.
:param path: [str] or [list(str)], the path to the target data structure level/content
:param value: [object] rhs assignment object
"""
if isinstance(path, str):
path = path.split(self.sep)
if not isinstance(path, list) or isinstance(path, tuple):
raise IOError("Input Error, expect list[str] type for path")
if len(path) < 1:
raise IOError("Input Error, missing section strings")
if not path[0] in self._data.keys():
if value is not None and len(path) == 1:
self._data[path[0]] = value
else:
self._data[path[0]] = {}
tmp = self._data[path[0]]
path.pop(0)
while True:
if len(path) == 0:
break
if path[0] not in tmp.keys():
if value is not None and len(path) == 1:
tmp[path[0]] = value
else:
tmp[path[0]] = {}
tmp = tmp[path[0]]
else:
tmp = tmp[path[0]]
path.pop(0)
def __len__(self):
    # number of top-level sections in the data dict
    return len(self._data)
+ def items(self):
+ return self.data.items()
+
def clear(self):
    """
    clears the instance data
    """
    LOG.debug("clear()")
    # reset to an empty plain dict; also called from __init__
    self._data = {}
def from_file(self, fname):
    """
    Loads data from file. Currently implemented .json and .xml file reader.

    :param fname: [str] filename
    """
    if not isinstance(fname, str):
        raise IOError("Input Error, expect str type for fname")
    # dispatch on the filename extension
    readers = {".json": self.read_json, ".xml": self.read_xml}
    for suffix, reader in readers.items():
        if fname.endswith(suffix):
            reader(fname)
            return
    LOG.error("Unknown filetype, expect [.json, .xml]")
    raise NotImplementedError("Unknown filetype, expect [.json, .xml]")
def read_json(self, fname):
    """
    Read json file

    :param fname: [str] input filename
    :raises IOError: on wrong input type, missing file, or parse failure
    """
    if not isinstance(fname, str):
        raise IOError("Input Error, expect str type for fname")
    if not os.path.isfile(fname):
        raise IOError(f"File {fname} not found!")
    LOG.debug(f"read_json({fname})")
    try:
        with open(fname, "r") as read_file:
            self._data = json.load(read_file)
        # convert string leaves like "3" or "[1,2]" to native types
        DeepDict.value_traverse(self.data, callback=DeepDict.parse_type)
    except Exception as e:
        # the original raise lacked the f prefix, so the literal text
        # "{fname}" was reported instead of the filename
        msg = f"Error while reading json file {fname} or while converting types"
        LOG.error(msg)
        raise IOError(msg) from e
def read_xml(self, fname):
    """
    Read xml file

    :param fname: [str] input filename
    :raises IOError: on wrong input type, missing file, or parse failure
    """
    if not isinstance(fname, str):
        raise IOError("Input Error, expect str type for fname")
    if not os.path.isfile(fname):
        raise IOError(f"File {fname} not found!")
    LOG.debug(f"read_xml({fname})")
    try:
        with open(fname, "r") as read_file:
            xml = "".join(read_file.readlines())
            self._data = xmltodict.parse(xml, attr_prefix='')
            # convert string leaves like "3" or "[1,2]" to native types
            DeepDict.value_traverse(self.data, callback=DeepDict.parse_type)
    except Exception as e:
        msg = f"Error while reading xml file {fname} or while converting types"
        LOG.error(msg)
        raise IOError(msg) from e
    # if written with DeepDict, the xml contains a root node called
    # deepdict which should be removed for consistency reasons
    if DEEPDICT_XML_ROOT in self._data.keys():
        self._data = self._data[DEEPDICT_XML_ROOT]
        self._data = dict(self.data)
    # convert the ordered dict structure to a default dict for consistency reasons
    convert_ordered2std_dict(self.data)
def to_file(self, fname):
    """
    Write to file, type is determined by checking the filename ending.
    Currently implemented is writing to json and to xml.

    :param fname: [str] filename
    :raises NotImplementedError: for unsupported file extensions
    """
    if not isinstance(fname, str):
        raise IOError("Input Error, expect str type for fname")
    if fname.endswith(".json"):
        self.write_json(fname)
    elif fname.endswith(".xml"):
        self.write_xml(fname)
    else:
        # original used an f-string with no placeholder here
        LOG.error("Unknown filetype, expect [.json, .xml]")
        raise NotImplementedError("Unknown filetype, expect [.json, .xml]")
def write_json(self, fname):
    """
    Dump data to json file.

    :param fname: [str] filename
    """
    if not isinstance(fname, str):
        raise IOError("Input Error, expect str type for fname")
    # make sure the target directory exists before opening the file
    check_dir_existance(os.path.dirname(fname))
    try:
        LOG.debug(f"write_json({fname})")
        with open(fname, "w") as write_file:
            json.dump(self.data, write_file)
    except Exception as e:
        LOG.error(f"Failed dumping to json file: {fname}")
        raise e
def write_xml(self, fname):
    """
    Dump data to xml file.

    :param fname: [str] filename
    """
    if not isinstance(fname, str):
        raise IOError("Input Error, expect str type for fname")
    # make sure the target directory exists before opening the file
    check_dir_existance(os.path.dirname(fname))
    # wrap the data in a DEEPDICT_XML_ROOT root node; read_xml strips it again
    xml = dicttoxml(self.data, custom_root=DEEPDICT_XML_ROOT, attr_type=False)
    LOG.debug(f"write_xml({fname})")
    try:
        with open(fname, "w") as write_file:
            write_file.write(xml.decode("utf-8"))
    except Exception as e:
        LOG.error(f"Failed dumping to xml file: {fname}")
        raise e
def has_section(self, section):
    # True if `section` occurs as a key anywhere (any nesting level) in the data
    return DeepDict.has_key(self.data, section)
@staticmethod
def get_from_path(data, path, sep="/"):
"""
Implements a nested dict access via a path like string like so path='target/section/path'
which is equivalent to my_dict['target']['section']['path'].
:param data: [dict] input dictionary
:param path: [str] pathlike string
:param sep: [str] path separator, default='/'
:return: [object]
"""
if not isinstance(data, dict):
LOG.error("Input Error, expect dict type for data")
raise IOError("Input Error, expect dict type for data")
if isinstance(path, str):
path = path.split(sep)
if not isinstance(path, list) or isinstance(path, tuple):
LOG.error(f"Input Error, expect list[str] type for path: {path}")
raise IOError("Input Error, expect list[str] type for path")
if not DeepDict.has_key(data, path[-1]):
LOG.error(f"Input Error, section {path[-1]} does not exist in dictionary")
raise IOError(f"Input Error, section {path[-1]} does not exist in dictionary")
try:
for k in path:
data = data[k]
except Exception as e:
LOG.error(f"Failed retrieving data from path {path} due to {e}")
raise LookupError(f"Failed retrieving data from path {path} due to {e}")
return data
@staticmethod
def has_key(data, section, already_found=False):
"""
Checks if input dictionary has a key called section. The already_found parameter
is for internal recursion checks.
:param data: [dict] input dictionary
:param section: [str] key string to search for
:param already_found: recursion criteria check
:return: [bool] section found
"""
if not isinstance(data, dict):
LOG.error("Input Error, expect dict type for obj")
raise IOError("Input Error, expect dict type for obj")
if not isinstance(section, str):
LOG.error(f"Input Error, expect dict type for obj {section}")
raise IOError(f"Input Error, expect dict type for obj {section}")
if already_found:
return True
found = False
for key, value in data.items():
if key == section:
found = True
if isinstance(value, dict):
found = DeepDict.has_key(data[key], section, found)
return found
@staticmethod
def value_traverse(data, callback=None):
"""
Dictionary filter function, walks through the input dict (obj) calling the callback function for each value.
The callback function return is assigned the the corresponding dict value.
:param data: [dict] input dictionary
:param callback:
"""
if not isinstance(data, dict):
LOG.error("Input Error, expect dict type for obj")
raise IOError("Input Error, expect dict type for obj")
if not isinstance(callback, types.FunctionType):
LOG.error("Input Error, expect function type for callback")
raise IOError("Input Error, expect function type for callback")
for key, value in data.items():
if isinstance(value, dict):
DeepDict.value_traverse(data[key], callback)
else:
data[key] = callback(value)
+ def transfer_attrs(self, cls, target_section):
+ def set(item):
+ setattr(cls, item[0], item[1])
+ DeepDict.sectionconstraint_item_traverse(self.data, target_section, callback=set, section=None)
+
+ @staticmethod
+ def sectionconstraint_item_traverse(data, target_section, callback=None, section=None):
+ """
+ Dictionary filter function, walks through the input dict (obj) calling the callback function for each item.
+ The callback function then is called with the key value pair as tuple input but only for the target section.
+ :param data: [dict] input dictionary
+ :param callback:
+ """
+ if not isinstance(data, dict):
+ LOG.error("Input Error, expect dict type for obj")
+ raise IOError("Input Error, expect dict type for obj")
+ if not isinstance(callback, types.FunctionType):
+ LOG.error("Input Error, expect function type for callback")
+ raise IOError("Input Error, expect function type for callback")
+ for key, value in data.items():
+ if isinstance(value, dict):
+ DeepDict.sectionconstraint_item_traverse(data[key], target_section, callback, key)
+ else:
+ if target_section == section:
+ callback((key, value))
+
+ @staticmethod
+ def item_traverse(data, callback=None):
+ """
+ Dictionary filter function, walks through the input dict (obj) calling the callback function for each item.
+ The callback function then is called with the key value pair as tuple input.
+ :param data: [dict] input dictionary
+ :param callback:
+ """
+ if not isinstance(data, dict):
+ LOG.error("Input Error, expect dict type for obj")
+ raise IOError("Input Error, expect dict type for obj")
+ if not isinstance(callback, types.FunctionType):
+ LOG.error("Input Error, expect function type for callback")
+ raise IOError("Input Error, expect function type for callback")
+ for key, value in data.items():
+ if isinstance(value, dict):
+ DeepDict.value_traverse(data[key], callback)
+ else:
+ callback((key, value))
+
@staticmethod
def parse_type(string):
"""
Type convert input string to float, int, list, tuple or string
:param string: [str] input string
:return: [T] converted output
"""
try:
a = float(string)
try:
b = int(string)
except ValueError:
return float(string)
if a == b:
return b
return a
except ValueError:
if string.startswith("[") and string.endswith("]"):
string = re.sub(' ', '', string)
elements = string[1:-1].split(",")
li = []
for e in elements:
li.append(DeepDict.parse_type(e))
return li
elif string.startswith("(") and string.endswith(")"):
elements = string[1:-1].split(",")
li = []
for e in elements:
li.append(DeepDict.parse_type(e))
return tuple(li)
return string
@property
def data(self):
    # the underlying nested dict holding all content
    return self._data
@data.setter
def data(self, value):
    # only plain dicts are accepted as instance data
    if not isinstance(value, dict):
        LOG.error(f"Input Error, expect dict type for value, but got {type(value)}")
        raise IOError(f"Input Error, expect dict type for value, but got {type(value)}")
    # clear first so no stale keys survive the assignment
    self.clear()
    self._data = value
@property
def sep(self):
    # path separator used when splitting string access paths
    return self._sep
@sep.setter
def sep(self, value):
    if not isinstance(value, str):
        LOG.error(f"Input Error, expect str type for value, but got {type(value)}")
        raise IOError(f"Input Error, expect str type for value, but got {type(value)}")
    self._sep = value
diff --git a/hyppopy/plugins/hyperopt_solver_plugin.py b/hyppopy/plugins/hyperopt_solver_plugin.py
index c94ae6a..7cc784f 100644
--- a/hyppopy/plugins/hyperopt_solver_plugin.py
+++ b/hyppopy/plugins/hyperopt_solver_plugin.py
@@ -1,70 +1,71 @@
# -*- coding: utf-8 -*-
#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import logging
from hyppopy.globals import DEBUGLEVEL
LOG = logging.getLogger(os.path.basename(__file__))
LOG.setLevel(DEBUGLEVEL)
from pprint import pformat
from hyperopt import fmin, tpe, hp, STATUS_OK, STATUS_FAIL, Trials
from yapsy.IPlugin import IPlugin
+from hyppopy.projectmanager import ProjectManager
from hyppopy.solverpluginbase import SolverPluginBase
class hyperopt_Solver(SolverPluginBase, IPlugin):
trials = None
best = None
def __init__(self):
SolverPluginBase.__init__(self)
LOG.debug("initialized")
def loss_function(self, params):
try:
loss = self.loss(self.data, params)
status = STATUS_OK
except Exception as e:
LOG.error(f"execution of self.loss(self.data, params) failed due to:\n {e}")
status = STATUS_FAIL
return {'loss': loss, 'status': status}
def execute_solver(self, parameter):
LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n")
self.trials = Trials()
try:
self.best = fmin(fn=self.loss_function,
space=parameter,
algo=tpe.suggest,
- max_evals=self.settings.max_iterations,
+ max_evals=ProjectManager.max_iterations,
trials=self.trials)
except Exception as e:
msg = f"internal error in hyperopt.fmin occured. {e}"
LOG.error(msg)
raise BrokenPipeError(msg)
def convert_results(self):
txt = ""
solution = dict([(k, v) for k, v in self.best.items() if v is not None])
txt += 'Solution Hyperopt Plugin\n========\n'
txt += "\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items()))
txt += "\n"
return txt
diff --git a/hyppopy/plugins/optunity_solver_plugin.py b/hyppopy/plugins/optunity_solver_plugin.py
index c92ab52..c5e47f6 100644
--- a/hyppopy/plugins/optunity_solver_plugin.py
+++ b/hyppopy/plugins/optunity_solver_plugin.py
@@ -1,72 +1,71 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import logging
from hyppopy.globals import DEBUGLEVEL
LOG = logging.getLogger(os.path.basename(__file__))
LOG.setLevel(DEBUGLEVEL)
from pprint import pformat
import optunity
from yapsy.IPlugin import IPlugin
+from hyppopy.projectmanager import ProjectManager
from hyppopy.solverpluginbase import SolverPluginBase
class optunity_Solver(SolverPluginBase, IPlugin):
solver_info = None
trials = None
best = None
status = None
def __init__(self):
SolverPluginBase.__init__(self)
LOG.debug("initialized")
def loss_function(self, **params):
try:
loss = self.loss(self.data, params)
self.status.append('ok')
return loss
except Exception as e:
LOG.error(f"computing loss failed due to:\n {e}")
self.status.append('fail')
return 1e9
def execute_solver(self, parameter):
LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n")
self.status = []
try:
self.best, self.trials, self.solver_info = optunity.minimize_structured(f=self.loss_function,
- num_evals=self.settings.max_iterations,
+ num_evals=ProjectManager.max_iterations,
search_space=parameter)
except Exception as e:
LOG.error(f"internal error in optunity.minimize_structured occured. {e}")
raise BrokenPipeError(f"internal error in optunity.minimize_structured occured. {e}")
def convert_results(self):
solution = dict([(k, v) for k, v in self.best.items() if v is not None])
txt = ""
txt += 'Solution Optunity Plugin\n========\n'
txt += "\n".join(map(lambda x: "%s \t %s" % (x[0], str(x[1])), solution.items()))
txt += f"\nSolver used: {self.solver_info['solver_name']}"
txt += f"\nOptimum: {self.trials.optimum}"
txt += f"\nIterations used: {self.trials.stats['num_evals']}"
txt += f"\nDuration: {self.trials.stats['time']} s\n"
return txt
diff --git a/hyppopy/projectmanager.py b/hyppopy/projectmanager.py
new file mode 100644
index 0000000..24eed29
--- /dev/null
+++ b/hyppopy/projectmanager.py
@@ -0,0 +1,55 @@
+# DKFZ
+#
+#
+# Copyright (c) German Cancer Research Center,
+# Division of Medical and Biological Informatics.
+# All rights reserved.
+#
+# This software is distributed WITHOUT ANY WARRANTY; without
+# even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE.
+#
+# See LICENSE.txt or http://www.mitk.org for details.
+#
+# Author: Sven Wanner (s.wanner@dkfz.de)
+
+from hyppopy.singleton import *
+from hyppopy.deepdict import DeepDict
+from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH
+
+import os
+import logging
+from hyppopy.globals import DEBUGLEVEL
+LOG = logging.getLogger(os.path.basename(__file__))
+LOG.setLevel(DEBUGLEVEL)
+
+
+@singleton_object
+class ProjectManager(metaclass=Singleton):
+
+ def __init__(self):
+ self.configfilename = None
+ self.config = None
+
+ def test_config(self):
+        #TODO test the config structure to fulfill the needs, raising a useful error if it does not
+ return True
+
+ def read_config(self, configfile):
+ self.configfilename = configfile
+ self.config = DeepDict(configfile)
+ if not self.test_config():
+ self.configfilename = None
+ self.config = None
+ return False
+
+ try:
+ self.config.transfer_attrs(self, SETTINGSCUSTOMPATH.split("/")[-1])
+ self.config.transfer_attrs(self, SETTINGSSOLVERPATH.split("/")[-1])
+ except Exception as e:
+ msg = f"transfering custom section as class attributes failed, " \
+ f"is the config path to your custom section correct? {SETTINGSCUSTOMPATH}. Exception {e}"
+ LOG.error(msg)
+ raise LookupError(msg)
+
+ return True
diff --git a/hyppopy/settingspluginbase.py b/hyppopy/settingspluginbase.py
index 07e0d37..3d8d49a 100644
--- a/hyppopy/settingspluginbase.py
+++ b/hyppopy/settingspluginbase.py
@@ -1,89 +1,77 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import abc
import os
import logging
from hyppopy.globals import DEBUGLEVEL
LOG = logging.getLogger(os.path.basename(__file__))
LOG.setLevel(DEBUGLEVEL)
from hyppopy.globals import SETTINGSSOLVERPATH, SETTINGSCUSTOMPATH
from hyppopy.deepdict import DeepDict
class SettingsPluginBase(object):
_data = None
_name = None
def __init__(self):
- self._data = DeepDict()
+ self._data = {}
@abc.abstractmethod
def convert_parameter(self):
raise NotImplementedError('users must define convert_parameter to use this base class')
def get_hyperparameter(self):
- return self.convert_parameter(self.data["hyperparameter"])
+ return self.convert_parameter(self.data)
def set(self, data):
self.data.clear()
self.data = data
def read(self, fname):
self.data.clear()
self.data.from_file(fname)
def write(self, fname):
self.data.to_file(fname)
- def set_attributes(self, cls):
- if self.data.has_section(SETTINGSSOLVERPATH.split('/')[-1]):
- attrs_sec = self.data[SETTINGSSOLVERPATH]
- for key, value in attrs_sec.items():
- setattr(cls, key, value)
- if self.data.has_section(SETTINGSCUSTOMPATH.split('/')[-1]):
- attrs_sec = self.data[SETTINGSCUSTOMPATH]
- for key, value in attrs_sec.items():
- setattr(cls, key, value)
-
@property
def data(self):
return self._data
@data.setter
def data(self, value):
if isinstance(value, dict):
- self._data.data = value
- elif isinstance(value, DeepDict):
self._data = value
+ elif isinstance(value, DeepDict):
+ self._data = value.data
else:
raise IOError(f"unexpected input type({type(value)}) for data, needs to be of type dict or DeepDict!")
@property
def name(self):
return self._name
@name.setter
def name(self, value):
if not isinstance(value, str):
LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead")
raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead")
self._name = value
diff --git a/hyppopy/solver.py b/hyppopy/solver.py
index 78d3917..1ae9740 100644
--- a/hyppopy/solver.py
+++ b/hyppopy/solver.py
@@ -1,84 +1,77 @@
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import logging
from hyppopy.globals import DEBUGLEVEL
LOG = logging.getLogger(os.path.basename(__file__))
LOG.setLevel(DEBUGLEVEL)
class Solver(object):
_name = None
_solver_plugin = None
_settings_plugin = None
def __init__(self):
pass
def set_data(self, data):
self.solver.set_data(data)
- def set_parameters(self, params):
+ def set_hyperparameters(self, params):
self.settings.set(params)
- self.settings.set_attributes(self.solver)
- self.settings.set_attributes(self.settings)
-
- def read_parameter(self, fname):
- self.settings.read(fname)
- self.settings.set_attributes(self.solver)
- self.settings.set_attributes(self.settings)
def set_loss_function(self, loss_func):
self.solver.set_loss_function(loss_func)
def run(self):
self.solver.settings = self.settings
self.solver.run()
def get_results(self):
return self.solver.get_results()
@property
def is_ready(self):
return self.solver is not None and self.settings is not None
@property
def solver(self):
return self._solver_plugin
@solver.setter
def solver(self, value):
self._solver_plugin = value
@property
def settings(self):
return self._settings_plugin
@settings.setter
def settings(self, value):
self._settings_plugin = value
@property
def name(self):
return self._name
@name.setter
def name(self, value):
if not isinstance(value, str):
LOG.error(f"Invalid input, str type expected for value, got {type(value)} instead")
raise IOError(f"Invalid input, str type expected for value, got {type(value)} instead")
self._name = value
diff --git a/hyppopy/solverfactory.py b/hyppopy/solverfactory.py
index cfbe1ed..697c33c 100644
--- a/hyppopy/solverfactory.py
+++ b/hyppopy/solverfactory.py
@@ -1,156 +1,156 @@
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
from yapsy.PluginManager import PluginManager
from hyppopy.globals import PLUGIN_DEFAULT_DIR
from hyppopy.deepdict import DeepDict
from hyppopy.solver import Solver
from hyppopy.singleton import *
import os
import logging
from hyppopy.globals import DEBUGLEVEL
LOG = logging.getLogger(os.path.basename(__file__))
LOG.setLevel(DEBUGLEVEL)
@singleton_object
class SolverFactory(metaclass=Singleton):
"""
This class is responsible for grabbing all plugins from the plugin folder arranging them into a
Solver class instances. These Solver class instances can be requested from the factory via the
get_solver method. The SolverFactory class is a Singleton class, so try not to instantiate it using
SolverFactory(), the consequences will be horrific. Instead use is like a class having static
functions only, SolverFactory.method().
"""
_plugin_dirs = []
_plugins = {}
def __init__(self):
print("Solverfactory: I'am alive!")
self.reset()
self.load_plugins()
LOG.debug("Solverfactory initialized")
def load_plugins(self):
"""
Load plugin modules from plugin paths
"""
LOG.debug("load_plugins()")
manager = PluginManager()
LOG.debug(f"setPluginPlaces(" + " ".join(map(str, self._plugin_dirs)))
manager.setPluginPlaces(self._plugin_dirs)
manager.collectPlugins()
for plugin in manager.getAllPlugins():
name_elements = plugin.plugin_object.__class__.__name__.split("_")
LOG.debug("found plugin " + " ".join(map(str, name_elements)))
if len(name_elements) != 2 or ("Solver" not in name_elements and "Settings" not in name_elements):
LOG.error(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.")
raise NameError(f"invalid plugin class naming for class {plugin.plugin_object.__class__.__name__}, the convention is libname_Solver or libname_Settings.")
if name_elements[0] not in self._plugins.keys():
self._plugins[name_elements[0]] = Solver()
self._plugins[name_elements[0]].name = name_elements[0]
if name_elements[1] == "Solver":
try:
obj = plugin.plugin_object.__class__()
obj.name = name_elements[0]
self._plugins[name_elements[0]].solver = obj
LOG.info(f"plugin: {name_elements[0]} Solver loaded")
except Exception as e:
LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}")
raise ImportError(f"Failed to instanciate class {plugin.plugin_object.__class__.__name__}")
elif name_elements[1] == "Settings":
try:
obj = plugin.plugin_object.__class__()
obj.name = name_elements[0]
self._plugins[name_elements[0]].settings = obj
LOG.info(f"plugin: {name_elements[0]} ParameterSpace loaded")
except Exception as e:
LOG.error(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}")
raise ImportError(f"failed to instanciate class {plugin.plugin_object.__class__.__name__}")
else:
LOG.error(f"failed loading plugin {name_elements[0]}, please check if naming conventions are kept!")
raise IOError(f"failed loading plugin {name_elements[0]}!, please check if naming conventions are kept!")
if len(self._plugins) == 0:
msg = "no plugins found, please check your plugin folder names or your plugin scripts for errors!"
LOG.error(msg)
raise IOError(msg)
def reset(self):
"""
Reset solver factory
"""
LOG.debug("reset()")
self._plugins = {}
self._plugin_dirs = []
self.add_plugin_dir(os.path.abspath(PLUGIN_DEFAULT_DIR))
def add_plugin_dir(self, dir):
"""
Add plugin directory
"""
LOG.debug(f"add_plugin_dir({dir})")
self._plugin_dirs.append(dir)
def list_solver(self):
"""
list all solvers available
:return: [list(str)]
"""
return list(self._plugins.keys())
def from_settings(self, settings):
if isinstance(settings, dict):
tmp = DeepDict()
tmp.data = settings
settings = tmp
elif isinstance(settings, str):
if not os.path.isfile(settings):
LOG.warning(f"input error, file {settings} not found!")
settings = DeepDict(settings)
if isinstance(settings, DeepDict):
if settings.has_section("use_plugin"):
try:
use_plugin = settings["settings/solver/use_plugin"]
except Exception as e:
LOG.warning("wrong settings path for use_plugin option detected, expecting the path settings/solver/use_plugin!")
solver = self.get_solver(use_plugin)
- solver.set_parameters(settings)
+ solver.set_hyperparameters(settings['hyperparameter'])
return solver
LOG.warning("failed to choose a solver, either the config file is missing the section settings/solver/use_plugin, or there might be a typo")
else:
msg = "unknown input error, expected DeepDict, dict or filename!"
LOG.error(msg)
raise IOError(msg)
return None
def get_solver(self, name):
"""
returns a solver by name tag
:param name: [str] solver name
:return: [Solver] instance
"""
if not isinstance(name, str):
msg = f"Invalid input, str type expected for name, got {type(name)} instead"
LOG.error(msg)
raise IOError(msg)
if name not in self.list_solver():
msg = f"failed solver request, a solver called {name} is not available, " \
f"check for typo or if your plugin failed while loading!"
LOG.error(msg)
raise LookupError(msg)
LOG.debug(f"get_solver({name})")
return self._plugins[name]
diff --git a/hyppopy/solverpluginbase.py b/hyppopy/solverpluginbase.py
index 63a16d7..e880e80 100644
--- a/hyppopy/solverpluginbase.py
+++ b/hyppopy/solverpluginbase.py
@@ -1,86 +1,84 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import abc
import os
import logging
from hyppopy.globals import DEBUGLEVEL
from hyppopy.settingspluginbase import SettingsPluginBase
LOG = logging.getLogger(os.path.basename(__file__))
LOG.setLevel(DEBUGLEVEL)
class SolverPluginBase(object):
data = None
loss = None
_settings = None
_name = None
def __init__(self):
pass
@abc.abstractmethod
def loss_function(self, params):
raise NotImplementedError('users must define loss_func to use this base class')
@abc.abstractmethod
def execute_solver(self):
raise NotImplementedError('users must define execute_solver to use this base class')
@abc.abstractmethod
def convert_results(self):
raise NotImplementedError('users must define convert_results to use this base class')
def set_data(self, data):
self.data = data
def set_loss_function(self, func):
self.loss = func
def get_results(self):
return self.convert_results()
def run(self):
self.execute_solver(self.settings.get_hyperparameter())
@property
def name(self):
return self._name
@name.setter
def name(self, value):
if not isinstance(value, str):
msg = f"Invalid input, str type expected for value, got {type(value)} instead"
LOG.error(msg)
raise IOError(msg)
self._name = value
@property
def settings(self):
return self._settings
@settings.setter
def settings(self, value):
if not isinstance(value, SettingsPluginBase):
msg = f"Invalid input, SettingsPluginBase type expected for value, got {type(value)} instead"
LOG.error(msg)
raise IOError(msg)
self._settings = value
diff --git a/hyppopy/tests/data/Iris/rf_config.json b/hyppopy/tests/data/Iris/rf_config.json
index 869c319..443e510 100644
--- a/hyppopy/tests/data/Iris/rf_config.json
+++ b/hyppopy/tests/data/Iris/rf_config.json
@@ -1,42 +1,43 @@
{"hyperparameter": {
"n_estimators": {
"domain": "uniform",
"data": "[3,500]",
"type": "int"
},
"criterion": {
"domain": "categorical",
"data": "[gini,entropy]",
"type": "str"
},
"max_depth": {
"domain": "uniform",
"data": "[3, 50]",
"type": "int"
},
"min_samples_split": {
"domain": "uniform",
"data": "[0.0001,1]",
"type": "float"
},
"min_samples_leaf": {
"domain": "uniform",
"data": "[0.0001,0.5]",
"type": "float"
},
"max_features": {
"domain": "categorical",
"data": "[auto,sqrt,log2]",
"type": "str"
}
},
"settings": {
"solver": {
"max_iterations": "3",
"use_plugin" : "optunity"
},
"custom": {
+ "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris",
"data_name": "train_data.npy",
"labels_name": "train_labels.npy"
}
}}
\ No newline at end of file
diff --git a/hyppopy/tests/data/Iris/rf_config.xml b/hyppopy/tests/data/Iris/rf_config.xml
index 925d164..b60530d 100644
--- a/hyppopy/tests/data/Iris/rf_config.xml
+++ b/hyppopy/tests/data/Iris/rf_config.xml
@@ -1,44 +1,45 @@
uniform
[3,200]
int
categorical
[gini,entropy]
str
uniform
[3, 50]
int
uniform
[0.0001,1]
float
uniform
[0.0001,0.5]
float
categorical
[auto,sqrt,log2]
str
3
optunity
+ D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris
train_data.npy
train_labels.npy
\ No newline at end of file
diff --git a/hyppopy/tests/data/Iris/svc_config.json b/hyppopy/tests/data/Iris/svc_config.json
index 0628f97..59fb433 100644
--- a/hyppopy/tests/data/Iris/svc_config.json
+++ b/hyppopy/tests/data/Iris/svc_config.json
@@ -1,32 +1,33 @@
{"hyperparameter": {
"C": {
"domain": "uniform",
"data": "[0,20]",
"type": "float"
},
"gamma": {
"domain": "uniform",
"data": "[0.0001,20.0]",
"type": "float"
},
"kernel": {
"domain": "categorical",
"data": "[linear, sigmoid, poly, rbf]",
"type": "str"
},
"decision_function_shape": {
"domain": "categorical",
"data": "[ovo,ovr]",
"type": "str"
}
},
"settings": {
"solver": {
"max_iterations": "3",
"use_plugin" : "optunity"
},
"custom": {
+ "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris",
"data_name": "train_data.npy",
"labels_name": "train_labels.npy"
}
}}
\ No newline at end of file
diff --git a/hyppopy/tests/data/Iris/svc_config.xml b/hyppopy/tests/data/Iris/svc_config.xml
index fb4f50b..f8ab4e3 100644
--- a/hyppopy/tests/data/Iris/svc_config.xml
+++ b/hyppopy/tests/data/Iris/svc_config.xml
@@ -1,34 +1,35 @@
uniform
[0,20]
float
uniform
[0.0001,20.0]
float
categorical
[linear,sigmoid,poly,rbf]
str
categorical
[ovo,ovr]
str
3
hyperopt
+ D:/Projects/Python/hyppopy/hyppopy/tests/data/Iris
train_data.npy
train_labels.npy
\ No newline at end of file
diff --git a/hyppopy/tests/data/Titanic/rf_config.json b/hyppopy/tests/data/Titanic/rf_config.json
index 7993c78..a637f35 100644
--- a/hyppopy/tests/data/Titanic/rf_config.json
+++ b/hyppopy/tests/data/Titanic/rf_config.json
@@ -1,27 +1,28 @@
{"hyperparameter": {
"n_estimators": {
"domain": "uniform",
"data": "[3,500]",
"type": "int"
},
"criterion": {
"domain": "categorical",
"data": "[gini,entropy]",
"type": "str"
},
"max_depth": {
"domain": "uniform",
"data": "[3, 50]",
"type": "int"
}
},
"settings": {
"solver": {
"max_iterations": "3",
"use_plugin" : "optunity"
},
"custom": {
+ "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic",
"data_name": "train_cleaned.csv",
"labels_name": "Survived"
}
}}
\ No newline at end of file
diff --git a/hyppopy/tests/data/Titanic/rf_config.xml b/hyppopy/tests/data/Titanic/rf_config.xml
index 5dd0797..fbfa828 100644
--- a/hyppopy/tests/data/Titanic/rf_config.xml
+++ b/hyppopy/tests/data/Titanic/rf_config.xml
@@ -1,29 +1,30 @@
uniform
[3,200]
int
categorical
[gini,entropy]
str
uniform
[3, 50]
int
3
optunity
+ D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic
train_cleaned.csv
Survived
\ No newline at end of file
diff --git a/hyppopy/tests/data/Titanic/svc_config.json b/hyppopy/tests/data/Titanic/svc_config.json
index 3291024..4066bb6 100644
--- a/hyppopy/tests/data/Titanic/svc_config.json
+++ b/hyppopy/tests/data/Titanic/svc_config.json
@@ -1,32 +1,33 @@
{"hyperparameter": {
"C": {
"domain": "uniform",
"data": "[0,20]",
"type": "float"
},
"gamma": {
"domain": "uniform",
"data": "[0.0001,20.0]",
"type": "float"
},
"kernel": {
"domain": "categorical",
"data": "[linear, sigmoid, poly, rbf]",
"type": "str"
},
"decision_function_shape": {
"domain": "categorical",
"data": "[ovo,ovr]",
"type": "str"
}
},
"settings": {
"solver": {
"max_iterations": "3",
"use_plugin" : "hyperopt"
},
"custom": {
+ "data_path": "D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic",
"data_name": "train_cleaned.csv",
"labels_name": "Survived"
}
}}
\ No newline at end of file
diff --git a/hyppopy/tests/data/Titanic/svc_config.xml b/hyppopy/tests/data/Titanic/svc_config.xml
index b26c191..1107c4a 100644
--- a/hyppopy/tests/data/Titanic/svc_config.xml
+++ b/hyppopy/tests/data/Titanic/svc_config.xml
@@ -1,34 +1,35 @@
uniform
[0,20]
float
uniform
[0.0001,20.0]
float
categorical
[linear,sigmoid,poly,rbf]
str
categorical
[ovo,ovr]
str
3
optunity
+ D:/Projects/Python/hyppopy/hyppopy/tests/data/Titanic
train_cleaned.csv
Survived
\ No newline at end of file
diff --git a/hyppopy/tests/test_deepdict.py b/hyppopy/tests/test_deepdict.py
index 0868cf5..fc5efe8 100644
--- a/hyppopy/tests/test_deepdict.py
+++ b/hyppopy/tests/test_deepdict.py
@@ -1,153 +1,163 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import unittest
from hyppopy.deepdict import DeepDict
DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
class DeepDictTestSuite(unittest.TestCase):
def setUp(self):
self.test_data = {
'widget': {
'debug': 'on',
'image': {'alignment': 'center',
'hOffset': 250,
'name': 'sun1',
'src': 'Images/Sun.png',
'vOffset': 250},
'text': {'alignment': 'center',
'data': 'Click Here',
'hOffset': 250,
'name': 'text1',
'onMouseUp': 'sun1.opacity = (sun1.opacity / 100) * 90;',
'size': 36,
'style': 'bold',
'vOffset': 100},
'window': {'height': 500,
'name': 'main_window',
'title': 'Sample Konfabulator Widget',
'width': 500}
}
}
self.test_data2 = {"test": {
"section": {
"var1": 100,
"var2": 200
}
}}
def test_fileIO(self):
dd_json = DeepDict(os.path.join(DATA_PATH, 'test_json.json'))
dd_xml = DeepDict(os.path.join(DATA_PATH, 'test_xml.xml'))
dd_dict = DeepDict(self.test_data)
self.assertTrue(list(self.test_data.keys())[0] == list(dd_json.data.keys())[0])
self.assertTrue(list(self.test_data.keys())[0] == list(dd_xml.data.keys())[0])
self.assertTrue(list(self.test_data.keys())[0] == list(dd_dict.data.keys())[0])
for key in self.test_data['widget'].keys():
self.assertTrue(self.test_data['widget'][key] == dd_json.data['widget'][key])
self.assertTrue(self.test_data['widget'][key] == dd_xml.data['widget'][key])
self.assertTrue(self.test_data['widget'][key] == dd_dict.data['widget'][key])
for key in self.test_data['widget'].keys():
if key == 'debug':
self.assertTrue(dd_json.data['widget']["debug"] == "on")
self.assertTrue(dd_xml.data['widget']["debug"] == "on")
self.assertTrue(dd_dict.data['widget']["debug"] == "on")
else:
for key2, value2 in self.test_data['widget'][key].items():
self.assertTrue(value2 == dd_json.data['widget'][key][key2])
self.assertTrue(value2 == dd_xml.data['widget'][key][key2])
self.assertTrue(value2 == dd_dict.data['widget'][key][key2])
dd_dict.to_file(os.path.join(DATA_PATH, 'write_to_json_test.json'))
dd_dict.to_file(os.path.join(DATA_PATH, 'write_to_xml_test.xml'))
self.assertTrue(os.path.isfile(os.path.join(DATA_PATH, 'write_to_json_test.json')))
self.assertTrue(os.path.isfile(os.path.join(DATA_PATH, 'write_to_xml_test.xml')))
dd_json = DeepDict(os.path.join(DATA_PATH, 'write_to_json_test.json'))
dd_xml = DeepDict(os.path.join(DATA_PATH, 'write_to_xml_test.xml'))
self.assertTrue(dd_json == dd_dict)
self.assertTrue(dd_xml == dd_dict)
try:
os.remove(os.path.join(DATA_PATH, 'write_to_json_test.json'))
os.remove(os.path.join(DATA_PATH, 'write_to_xml_test.xml'))
except Exception as e:
print(e)
print("Warning: Failed to delete temporary data during tests!")
def test_has_section(self):
dd = DeepDict(self.test_data)
self.assertTrue(dd.has_section('hOffset'))
self.assertTrue(dd.has_section('window'))
self.assertTrue(dd.has_section('widget'))
self.assertTrue(dd.has_section('style'))
self.assertTrue(dd.has_section('window'))
self.assertTrue(dd.has_section('title'))
self.assertFalse(dd.has_section('notasection'))
def test_data_access(self):
dd = DeepDict(self.test_data)
self.assertEqual(dd['widget/window/height'], 500)
self.assertEqual(dd['widget/image/name'], 'sun1')
self.assertTrue(isinstance(dd['widget/window'], dict))
self.assertEqual(len(dd['widget/window']), 4)
dd = DeepDict(path_sep=".")
dd.data = self.test_data
self.assertEqual(dd['widget.window.height'], 500)
self.assertEqual(dd['widget.image.name'], 'sun1')
self.assertTrue(isinstance(dd['widget.window'], dict))
self.assertEqual(len(dd['widget.window']), 4)
def test_data_adding(self):
dd = DeepDict()
dd["test/section/var1"] = 100
dd["test/section/var2"] = 200
self.assertTrue(dd.data == self.test_data2)
dd = DeepDict()
dd["test"] = {}
dd["test/section"] = {}
dd["test/section/var1"] = 100
dd["test/section/var2"] = 200
self.assertTrue(dd.data == self.test_data2)
def test_sample_space(self):
dd = DeepDict(os.path.join(DATA_PATH, 'test_paramset.json'))
self.assertEqual(len(dd[['parameter', 'activation', 'data']]), 4)
self.assertEqual(dd['parameter/activation/data'], ['ReLU', 'tanh', 'sigm', 'ELU'])
self.assertTrue(isinstance(dd['parameter/activation/data'], list))
self.assertTrue(isinstance(dd['parameter/activation/data'][0], str))
self.assertEqual(dd['parameter/layerdepth/data'], [3, 20])
self.assertTrue(isinstance(dd['parameter/layerdepth/data'], list))
self.assertTrue(isinstance(dd['parameter/layerdepth/data'][0], int))
self.assertTrue(isinstance(dd['parameter/learningrate/data'][0], float))
self.assertEqual(dd['parameter/learningrate/data'][0], 1e-5)
self.assertEqual(dd['parameter/learningrate/data'][1], 10.0)
def test_len(self):
dd = DeepDict(os.path.join(DATA_PATH, 'test_paramset.json'))
self.assertEqual(len(dd), 1)
+ def test_setattr(self):
+ dd = DeepDict(os.path.join(DATA_PATH, 'iris_svc_parameter.xml'))
+
+ class Foo(object):
+ def __init__(self):
+ pass
+ foo = Foo
+ dd.transfer_attrs(foo, 'solver')
+ self.assertEqual(foo.max_iterations, 50)
+ self.assertEqual(foo.use_plugin, 'optunity')
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/hyppopy/tests/test_projectmanager.py b/hyppopy/tests/test_projectmanager.py
new file mode 100644
index 0000000..0e52fe1
--- /dev/null
+++ b/hyppopy/tests/test_projectmanager.py
@@ -0,0 +1,38 @@
+# DKFZ
+#
+#
+# Copyright (c) German Cancer Research Center,
+# Division of Medical and Biological Informatics.
+# All rights reserved.
+#
+# This software is distributed WITHOUT ANY WARRANTY; without
+# even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE.
+#
+# See LICENSE.txt or http://www.mitk.org for details.
+#
+# Author: Sven Wanner (s.wanner@dkfz.de)
+
+import os
+import unittest
+from hyppopy.projectmanager import ProjectManager
+
+
+DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
+
+
+class ProjectManagerTestSuite(unittest.TestCase):
+
+ def setUp(self):
+ pass
+
+ def test_attr_transfer(self):
+ ProjectManager.read_config(os.path.join(DATA_PATH, *('Titanic', 'rf_config.xml')))
+ self.assertEqual(ProjectManager.data_name, 'train_cleaned.csv')
+ self.assertEqual(ProjectManager.labels_name, 'Survived')
+ self.assertEqual(ProjectManager.max_iterations, 3)
+ self.assertEqual(ProjectManager.use_plugin, 'optunity')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/hyppopy/tests/test_workflows.py b/hyppopy/tests/test_workflows.py
index 8071a1a..6b2c848 100644
--- a/hyppopy/tests/test_workflows.py
+++ b/hyppopy/tests/test_workflows.py
@@ -1,130 +1,96 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import unittest
from hyppopy.globals import TESTDATA_DIR
IRIS_DATA = os.path.join(TESTDATA_DIR, 'Iris')
TITANIC_DATA = os.path.join(TESTDATA_DIR, 'Titanic')
+from hyppopy.projectmanager import ProjectManager
from hyppopy.workflows.svc_usecase.svc_usecase import svc_usecase
from hyppopy.workflows.randomforest_usecase.randomforest_usecase import randomforest_usecase
-class Args(object):
-
- def __init__(self):
- pass
-
- def set_arg(self, name, value):
- setattr(self, name, value)
-
-
class WorkflowTestSuite(unittest.TestCase):
def setUp(self):
self.results = []
def test_workflow_svc_on_iris_from_xml(self):
- svc_args_xml = Args()
- svc_args_xml.set_arg('plugin', '')
- svc_args_xml.set_arg('data', IRIS_DATA)
- svc_args_xml.set_arg('config', os.path.join(IRIS_DATA, 'svc_config.xml'))
- uc = svc_usecase(svc_args_xml)
+ ProjectManager.read_config(os.path.join(IRIS_DATA, 'svc_config.xml'))
+ uc = svc_usecase()
uc.run()
self.results.append(uc.get_results())
self.assertTrue(uc.get_results().find("Solution") != -1)
def test_workflow_rf_on_iris_from_xml(self):
- rf_args_xml = Args()
- rf_args_xml.set_arg('plugin', '')
- rf_args_xml.set_arg('data', IRIS_DATA)
- rf_args_xml.set_arg('config', os.path.join(IRIS_DATA, 'rf_config.xml'))
- uc = svc_usecase(rf_args_xml)
+ ProjectManager.read_config(os.path.join(IRIS_DATA, 'rf_config.xml'))
+ uc = svc_usecase()
uc.run()
self.results.append(uc.get_results())
self.assertTrue(uc.get_results().find("Solution") != -1)
def test_workflow_svc_on_iris_from_json(self):
- svc_args_json = Args()
- svc_args_json.set_arg('plugin', '')
- svc_args_json.set_arg('data', IRIS_DATA)
- svc_args_json.set_arg('config', os.path.join(IRIS_DATA, 'svc_config.json'))
- uc = svc_usecase(svc_args_json)
+ ProjectManager.read_config(os.path.join(IRIS_DATA, 'svc_config.json'))
+ uc = svc_usecase()
uc.run()
self.results.append(uc.get_results())
self.assertTrue(uc.get_results().find("Solution") != -1)
def test_workflow_rf_on_iris_from_json(self):
- rf_args_json = Args()
- rf_args_json.set_arg('plugin', '')
- rf_args_json.set_arg('data', IRIS_DATA)
- rf_args_json.set_arg('config', os.path.join(IRIS_DATA, 'rf_config.json'))
- uc = randomforest_usecase(rf_args_json)
+ ProjectManager.read_config(os.path.join(IRIS_DATA, 'rf_config.json'))
+ uc = randomforest_usecase()
uc.run()
self.results.append(uc.get_results())
self.assertTrue(uc.get_results().find("Solution") != -1)
- def test_workflow_svc_on_titanic_from_xml(self):
- svc_args_xml = Args()
- svc_args_xml.set_arg('plugin', '')
- svc_args_xml.set_arg('data', TITANIC_DATA)
- svc_args_xml.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.xml'))
- uc = svc_usecase(svc_args_xml)
- uc.run()
- self.results.append(uc.get_results())
- self.assertTrue(uc.get_results().find("Solution") != -1)
+ # def test_workflow_svc_on_titanic_from_xml(self):
+ # ProjectManager.read_config(os.path.join(TITANIC_DATA, 'svc_config.xml'))
+ # uc = svc_usecase()
+ # uc.run()
+ # self.results.append(uc.get_results())
+ # self.assertTrue(uc.get_results().find("Solution") != -1)
def test_workflow_rf_on_titanic_from_xml(self):
- rf_args_xml = Args()
- rf_args_xml.set_arg('plugin', '')
- rf_args_xml.set_arg('data', TITANIC_DATA)
- rf_args_xml.set_arg('config', os.path.join(TITANIC_DATA, 'rf_config.xml'))
- uc = randomforest_usecase(rf_args_xml)
+ ProjectManager.read_config(os.path.join(TITANIC_DATA, 'rf_config.xml'))
+ uc = randomforest_usecase()
uc.run()
self.results.append(uc.get_results())
self.assertTrue(uc.get_results().find("Solution") != -1)
- def test_workflow_svc_on_titanic_from_json(self):
- svc_args_json = Args()
- svc_args_json.set_arg('plugin', '')
- svc_args_json.set_arg('data', TITANIC_DATA)
- svc_args_json.set_arg('config', os.path.join(TITANIC_DATA, 'svc_config.json'))
- uc = svc_usecase(svc_args_json)
- uc.run()
- self.results.append(uc.get_results())
- self.assertTrue(uc.get_results().find("Solution") != -1)
+ # def test_workflow_svc_on_titanic_from_json(self):
+ # ProjectManager.read_config(os.path.join(TITANIC_DATA, 'svc_config.json'))
+ # uc = svc_usecase()
+ # uc.run()
+ # self.results.append(uc.get_results())
+ # self.assertTrue(uc.get_results().find("Solution") != -1)
def test_workflow_rf_on_titanic_from_json(self):
- rf_args_json = Args()
- rf_args_json.set_arg('plugin', '')
- rf_args_json.set_arg('data', TITANIC_DATA)
- rf_args_json.set_arg('config', os.path.join(TITANIC_DATA, 'rf_config.json'))
- uc = randomforest_usecase(rf_args_json)
+ ProjectManager.read_config(os.path.join(TITANIC_DATA, 'rf_config.json'))
+ uc = randomforest_usecase()
uc.run()
self.results.append(uc.get_results())
self.assertTrue(uc.get_results().find("Solution") != -1)
def tearDown(self):
print("")
for r in self.results:
print(r)
if __name__ == '__main__':
unittest.main()
diff --git a/hyppopy/workflows/dataloader/__init__.py b/hyppopy/workflows/dataloader/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/hyppopy/workflows/dataloader/dataloaderbase.py b/hyppopy/workflows/dataloader/dataloaderbase.py
new file mode 100644
index 0000000..83cd117
--- /dev/null
+++ b/hyppopy/workflows/dataloader/dataloaderbase.py
@@ -0,0 +1,36 @@
+# DKFZ
+#
+#
+# Copyright (c) German Cancer Research Center,
+# Division of Medical and Biological Informatics.
+# All rights reserved.
+#
+# This software is distributed WITHOUT ANY WARRANTY; without
+# even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE.
+#
+# See LICENSE.txt or http://www.mitk.org for details.
+#
+# Author: Sven Wanner (s.wanner@dkfz.de)
+
+import abc
+
+
+class DataLoaderBase(object):
+
+ def __init__(self):
+ self.data = None
+
+ def start(self, **kwargs):
+ self.read(**kwargs)
+ if self.data is None:
+ raise AttributeError("data is empty, did you missed to assign it while implementing read...?")
+ self.preprocess(**kwargs)
+
+ @abc.abstractmethod
+ def read(self, **kwargs):
+ raise NotImplementedError("the read method has to be implemented in classes derived from DataLoader")
+
+ @abc.abstractmethod
+ def preprocess(self, **kwargs):
+ pass
diff --git a/hyppopy/workflows/dataloader/simpleloader.py b/hyppopy/workflows/dataloader/simpleloader.py
new file mode 100644
index 0000000..6760cfd
--- /dev/null
+++ b/hyppopy/workflows/dataloader/simpleloader.py
@@ -0,0 +1,41 @@
+# DKFZ
+#
+#
+# Copyright (c) German Cancer Research Center,
+# Division of Medical and Biological Informatics.
+# All rights reserved.
+#
+# This software is distributed WITHOUT ANY WARRANTY; without
+# even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE.
+#
+# See LICENSE.txt or http://www.mitk.org for details.
+#
+# Author: Sven Wanner (s.wanner@dkfz.de)
+
+import os
+import numpy as np
+import pandas as pd
+
+from hyppopy.workflows.dataloader.dataloaderbase import DataLoaderBase
+
+
+class SimpleDataLoaderBase(DataLoaderBase):
+
+ def read(self, **kwargs):
+ if kwargs['data_name'].endswith(".npy"):
+ if not kwargs['labels_name'].endswith(".npy"):
+ raise IOError("Expect both data_name and labels_name being of type .npy!")
+ self.data = [np.load(os.path.join(kwargs['path'], kwargs['data_name'])), np.load(os.path.join(kwargs['path'], kwargs['labels_name']))]
+ elif kwargs['data_name'].endswith(".csv"):
+ try:
+ dataset = pd.read_csv(os.path.join(kwargs['path'], kwargs['data_name']))
+ y = dataset[kwargs['labels_name']].values
+ X = dataset.drop([kwargs['labels_name']], axis=1).values
+ self.data = [X, y]
+ except Exception as e:
+ print("Precondition violation, this usage case expects as data_name a "
+ "csv file and as label_name a name of a column in this csv table!")
+ else:
+ raise NotImplementedError("This combination of data_name and labels_name "
+ "does not yet exist, feel free to add it")
diff --git a/hyppopy/workflows/dataloader/unetloader.py b/hyppopy/workflows/dataloader/unetloader.py
new file mode 100644
index 0000000..9d71dcb
--- /dev/null
+++ b/hyppopy/workflows/dataloader/unetloader.py
@@ -0,0 +1,79 @@
+# DKFZ
+#
+#
+# Copyright (c) German Cancer Research Center,
+# Division of Medical and Biological Informatics.
+# All rights reserved.
+#
+# This software is distributed WITHOUT ANY WARRANTY; without
+# even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE.
+#
+# See LICENSE.txt or http://www.mitk.org for details.
+#
+# Author: Sven Wanner (s.wanner@dkfz.de)
+
+import os
+from collections import defaultdict
+from hyppopy.workflows.dataloader.dataloaderbase import DataLoaderBase
+
+
+class UnetDataLoaderBase(DataLoaderBase):
+
+ def read(self, **kwargs):
+ pass
+
+ def subfiles(self, folder, join=True, prefix=None, suffix=None, sort=True):
+ if join:
+ l = os.path.join
+ else:
+ l = lambda x, y: y
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
+ and (prefix is None or i.startswith(prefix))
+ and (suffix is None or i.endswith(suffix))]
+ if sort:
+ res.sort()
+ return res
+
+ def preprocess(self, **kwargs):
+ image_dir = os.path.join(kwargs['root_dir'], kwargs['image_dir'])
+ label_dir = os.path.join(kwargs['root_dir'], kwargs['labels_dir'])
+ output_dir = os.path.join(kwargs['root_dir'], kwargs['output_dir'])
+ classes = kwargs['classes']
+
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ print('Created' + output_dir + '...')
+
+ class_stats = defaultdict(int)
+ total = 0
+
+ nii_files = self.subfiles(image_dir, suffix=".nii.gz", join=False)
+
+ for i in range(0, len(nii_files)):
+ if nii_files[i].startswith("._"):
+ nii_files[i] = nii_files[i][2:]
+
+ for f in nii_files:
+ image, _ = load(os.path.join(image_dir, f))
+ label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))
+
+ print(f)
+
+ for i in range(classes):
+ class_stats[i] += np.sum(label == i)
+ total += np.sum(label == i)
+
+ image = (image - image.min()) / (image.max() - image.min())
+
+ image = reshape(image, append_value=0, new_shape=(64, 64, 64))
+ label = reshape(label, append_value=0, new_shape=(64, 64, 64))
+
+ result = np.stack((image, label))
+
+ np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result)
+ print(f)
+
+ print(total)
+ for i in range(classes):
+ print(class_stats[i], class_stats[i] / total)
diff --git a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py
index c7ca0bc..5c14b6e 100644
--- a/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py
+++ b/hyppopy/workflows/randomforest_usecase/randomforest_usecase.py
@@ -1,38 +1,38 @@
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
-from hyppopy.workflows.workflowbase import Workflow
-from hyppopy.workflows.datalaoder.simpleloader import SimpleDataLoader
+from hyppopy.projectmanager import ProjectManager
+from hyppopy.workflows.workflowbase import WorkflowBase
+from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase
-class randomforest_usecase(Workflow):
-
- def __init__(self, args):
- Workflow.__init__(self, args)
+class randomforest_usecase(WorkflowBase):
def setup(self):
- dl = SimpleDataLoader()
- dl.read(path=self.args.data, data_name=self.solver.settings.data_name, labels_name=self.solver.settings.labels_name)
- self.solver.set_data(dl.get())
+ dl = SimpleDataLoaderBase()
+ dl.start(path=ProjectManager.data_path,
+ data_name=ProjectManager.data_name,
+ labels_name=ProjectManager.labels_name)
+ self.solver.set_data(dl.data)
def blackbox_function(self, data, params):
if "n_estimators" in params.keys():
params["n_estimators"] = int(round(params["n_estimators"]))
clf = RandomForestClassifier(**params)
return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean()
diff --git a/hyppopy/workflows/svc_usecase/svc_usecase.py b/hyppopy/workflows/svc_usecase/svc_usecase.py
index 18f93f1..4108969 100644
--- a/hyppopy/workflows/svc_usecase/svc_usecase.py
+++ b/hyppopy/workflows/svc_usecase/svc_usecase.py
@@ -1,41 +1,38 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import numpy as np
import pandas as pd
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score
-from hyppopy.workflows.workflowbase import Workflow
-from hyppopy.workflows.datalaoder.simpleloader import SimpleDataLoader
-
+from hyppopy.projectmanager import ProjectManager
+from hyppopy.workflows.workflowbase import WorkflowBase
+from hyppopy.workflows.dataloader.simpleloader import SimpleDataLoaderBase
-class svc_usecase(Workflow):
- def __init__(self, args):
- Workflow.__init__(self, args)
+class svc_usecase(WorkflowBase):
def setup(self):
- dl = SimpleDataLoader()
- dl.read(path=self.args.data, data_name=self.solver.settings.data_name,
- labels_name=self.solver.settings.labels_name)
- self.solver.set_data(dl.get())
+ dl = SimpleDataLoaderBase()
+ dl.start(path=ProjectManager.data_path,
+ data_name=ProjectManager.data_name,
+ labels_name=ProjectManager.labels_name)
+ self.solver.set_data(dl.data)
def blackbox_function(self, data, params):
clf = SVC(**params)
return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean()
diff --git a/hyppopy/workflows/unet_usecase/unet_uscase_utils.py b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py
index ae1f4ce..9b9147c 100644
--- a/hyppopy/workflows/unet_usecase/unet_uscase_utils.py
+++ b/hyppopy/workflows/unet_usecase/unet_uscase_utils.py
@@ -1,419 +1,417 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
import torch
import pickle
import fnmatch
import numpy as np
from torch import nn
from medpy.io import load
from collections import defaultdict
from abc import ABCMeta, abstractmethod
def sum_tensor(input, axes, keepdim=False):
    """Sum a tensor over the given axes (duplicates ignored), optionally keeping reduced dims."""
    reduce_axes = np.unique(axes).astype(int)
    result = input
    if keepdim:
        # with keepdim the axis indices stay stable, so reduction order is irrelevant
        for axis in reduce_axes:
            result = result.sum(int(axis), keepdim=True)
    else:
        # without keepdim each reduction shifts later axis indices,
        # so collapse from the highest axis downwards
        for axis in sorted(reduce_axes, reverse=True):
            result = result.sum(int(axis))
    return result
def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None):
    """
    Negative soft dice computed over the whole batch: batch and spatial axes
    are reduced together, keeping one dice term per class channel.

    :param net_output: predicted probabilities, shape (B, C, spatial...) — assumed from the axis handling, confirm
    :param gt: one-hot ground truth with the same shape as net_output
    :param smooth: additive smoothing in the denominator
    :param smooth_in_nom: additive smoothing in the numerator
    :param background_weight: weight applied to channel 0 in the final mean
    :param rebalance_weights: optional per-class weights (numpy array, converted
        via torch.from_numpy) scaling tp and fn
    :return: scalar tensor, negative weighted mean dice
    """
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[1:]  # this is the case when use_bg=False
    # reduce over batch axis (0) and all spatial axes, keep the class axis (1)
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    tp = sum_tensor(net_output * gt, axes, keepdim=False)        # soft true positives per class
    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)  # soft false negatives per class
    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)  # soft false positives per class
    weights = torch.ones(tp.shape)
    weights[0] = background_weight  # channel 0 is treated as background
    if net_output.device.type == "cuda":
        # keep the weights on the same GPU as the prediction
        weights = weights.cuda(net_output.device.index)
    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float()
        if net_output.device.type == "cuda":
            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
        # rebalancing scales tp and fn only; fp is deliberately left untouched
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights
    # negative weighted mean of (2*tp + eps_nom) / (2*tp + fp + fn + eps_denom)
    result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean()
    return result
def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
    """
    Negative soft dice averaged over batch and class channels; the per-sample,
    per-class dice is computed by reducing only the spatial axes.

    :param net_output: predicted probabilities, shape (B, C, spatial...) — assumed, confirm against caller
    :param gt: one-hot ground truth with the same shape as net_output
    :param smooth: additive smoothing in the denominator
    :param smooth_in_nom: additive smoothing in the numerator
    :return: scalar tensor, negative mean dice
    """
    # reduce the spatial axes only, keeping (batch, class)
    axes = tuple(range(2, len(net_output.size())))
    intersect = sum_tensor(net_output * gt, axes, keepdim=False)
    denom = sum_tensor(net_output + gt, axes, keepdim=False)
    # Fix: the original multiplied by an undefined name `weights` (a NameError
    # at runtime, flagged by a TODO in the original code); plain soft dice has
    # no per-class weighting at this point.
    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
    return result
class SoftDiceLoss(nn.Module):
    # Negative soft dice as a differentiable loss: per-sample dice via
    # soft_dice, or dice over the whole batch via soft_dice_per_batch_2.
    def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True, background_weight=1, rebalance_weights=None):
        """
        :param smooth: additive smoothing for the dice denominator
        :param apply_nonlin: optional callable (e.g. softmax) applied to the
            network output before computing the dice
        :param batch_dice: if True compute dice over the whole batch,
            otherwise per sample
        :param do_bg: if False the background channel (index 0) is excluded
        :param smooth_in_nom: if True use `smooth` also in the numerator
        :param background_weight: weight of the background class (batch_dice only)
        :param rebalance_weights: per-class weights applied to tp/fn (batch_dice only)
        """
        super(SoftDiceLoss, self).__init__()
        if not do_bg:
            assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy"
        self.rebalance_weights = rebalance_weights
        self.background_weight = background_weight
        self.smooth_in_nom = smooth_in_nom
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.y_onehot = None
        if not smooth_in_nom:
            self.nom_smooth = 0
        else:
            self.nom_smooth = smooth
        # NOTE(review): forward() passes self.smooth_in_nom (a bool) as the
        # numerator smoothing instead of the self.nom_smooth computed here;
        # with the defaults True == 1 so results coincide, but for smooth != 1
        # the two diverge — confirm which is intended.

    def forward(self, x, y):
        """
        :param x: network output, shape (B, C, X, Y(, Z))
        :param y: integer label map, reshaped to (B, 1, X, Y(, Z)) if needed
            and one-hot encoded against x's channel axis
        :return: scalar loss tensor (negative dice)
        """
        with torch.no_grad():
            y = y.long()
        shp_x = x.shape
        shp_y = y.shape
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        if len(shp_x) != len(shp_y):
            # insert the singleton channel axis expected by scatter_
            y = y.view((shp_y[0], 1, *shp_y[1:]))
        # now x and y should have shape (B, C, X, Y(, Z))) and (B, 1, X, Y(, Z))), respectively
        y_onehot = torch.zeros(shp_x)
        if x.device.type == "cuda":
            # allocate the one-hot buffer on the same GPU as the prediction
            y_onehot = y_onehot.cuda(x.device.index)
        y_onehot.scatter_(1, y, 1)
        if not self.do_bg:
            # drop the background channel from both prediction and target
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]
        if not self.batch_dice:
            if self.background_weight != 1 or (self.rebalance_weights is not None):
                raise NotImplementedError("nah son")
            l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom)
        else:
            l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom,
                                      background_weight=self.background_weight,
                                      rebalance_weights=self.rebalance_weights)
        return l
def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
    """
    Walk base_dir and collect the .npy files whose basename (extension
    stripped) appears in `keys`, recording for each its length along axis 1
    and the (file_index, slice_index) pairs inside the offset margin.

    Note: with keys=None nothing is collected (original behavior preserved).

    :param base_dir: root directory to walk
    :param pattern: fnmatch pattern for candidate files
    :param slice_offset: number of slices skipped at both ends
    :param keys: iterable of accepted basenames (without extension)
    :return: (file_paths, file_lengths, axial_slices) tuple
    """
    file_paths = []
    file_lengths = []
    axial_slices = []
    for root, _dirs, files in os.walk(base_dir):
        accepted = 0  # index among accepted files, restarts per directory
        for filename in sorted(fnmatch.filter(files, pattern)):
            if keys is None or filename[:-4] not in keys:
                continue
            full_path = os.path.join(root, filename)
            # mmap keeps large volumes out of memory; only the header is read here
            arr = np.load(full_path, mmap_mode="r")
            file_paths.append(full_path)
            file_lengths.append(arr.shape[1])
            axial_slices.extend(
                (accepted, j) for j in range(slice_offset, file_lengths[-1] - slice_offset))
            accepted += 1
    return file_paths, file_lengths, axial_slices,
class SlimDataLoaderBase(object):
    # Minimal iterable data-loader base: stores the dataset and batch size and
    # yields batches forever via generate_train_batch().
    def __init__(self, data, batch_size, number_of_threads_in_multithreaded=None):
        """
        Slim version of DataLoaderBase (which is now deprecated). Only provides very simple functionality.
        You must derive from this class to implement your own DataLoader. You must override self.generate_train_batch()
        If you use our MultiThreadedAugmenter you will need to also set and use number_of_threads_in_multithreaded. See
        multithreaded_dataloading in examples!
        :param data: will be stored in self._data. You can use it to generate your batches in self.generate_train_batch()
        :param batch_size: will be stored in self.batch_size for use in self.generate_train_batch()
        :param number_of_threads_in_multithreaded: will be stored in self.number_of_threads_in_multithreaded.
        None per default. If you wish to iterate over all your training data only once per epoch, you must coordinate
        your Dataloaders and you will need this information
        """
        # NOTE(review): Python-2 style metaclass assignment; it has no effect
        # in Python 3 and, placed inside __init__, would not have worked in
        # Python 2 either — kept byte-identical, confirm intent.
        __metaclass__ = ABCMeta
        self.number_of_threads_in_multithreaded = number_of_threads_in_multithreaded
        self._data = data
        self.batch_size = batch_size
        self.thread_id = 0

    def set_thread_id(self, thread_id):
        # id assigned by the multithreaded augmenter driving this loader
        self.thread_id = thread_id

    def __iter__(self):
        return self

    def __next__(self):
        # every iteration step delegates to the subclass batch generator
        return self.generate_train_batch()

    @abstractmethod
    def generate_train_batch(self):
        """
        Override this: build and return one batch of size self.batch_size
        from self._data.
        """
        pass
class NumpyDataLoader(SlimDataLoaderBase):
    """Serves batches of 2D slices drawn from the .npy volumes under base_dir."""
    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 seed=None, file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None):
        # Index the dataset once: file paths, per-file slice counts and the
        # flat list of (file index, slice index) pairs.
        self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern, slice_offset=0, keys=keys, )
        # NOTE(review): num_batches is passed into the base class slot named
        # number_of_threads_in_multithreaded -- confirm this is intended.
        super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches)
        self.batch_size = batch_size
        self.use_next = False
        # NOTE(review): both branches leave use_next False; the mode check
        # currently has no effect -- verify against the caller's expectation.
        if mode == "train":
            self.use_next = False
        # slice_idxs is the (optionally reshuffled) access order over slices.
        self.slice_idxs = list(range(0, len(self.slices)))
        self.data_len = len(self.slices)
        # Cap the number of batches; the +10 allows a few extra partial passes.
        self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)
        if isinstance(label_slice, int):
            # Normalize a single channel index into a 1-tuple.
            label_slice = (label_slice,)
        self.input_slice = input_slice
        self.label_slice = label_slice
        # Array view over the (file, slice) pairs for fancy-indexed batching.
        self.np_data = np.asarray(self.slices)
    def reshuffle(self):
        """Randomize the slice access order in place."""
        print("Reshuffle...")
        random.shuffle(self.slice_idxs)
        print("Initializing... this might take a while...")
    def generate_train_batch(self):
        """Return one batch built from a random sample of (file, slice) pairs."""
        open_arr = random.sample(self._data, self.batch_size)
        return self.get_data_from_array(open_arr)
    def __len__(self):
        # Number of full batches available, bounded by num_batches.
        n_items = min(self.data_len // self.batch_size, self.num_batches)
        return n_items
    def __getitem__(self, item):
        """Return batch number *item*; raises StopIteration past the end."""
        slice_idxs = self.slice_idxs
        data_len = len(self.slices)
        np_data = self.np_data
        # NOTE(review): "item > len(self)" lets item == len(self) through to
        # the modulo logic below -- confirm the intended boundary.
        if item > len(self):
            raise StopIteration()
        if (item * self.batch_size) == data_len:
            raise StopIteration()
        start_idx = (item * self.batch_size) % data_len
        stop_idx = ((item + 1) * self.batch_size) % data_len
        # A batch ending exactly at the data length must not wrap to 0.
        if ((item + 1) * self.batch_size) == data_len:
            stop_idx = data_len
        if stop_idx > start_idx:
            idxs = slice_idxs[start_idx:stop_idx]
        else:
            # Wrap-around (partial trailing batch) is not served.
            raise StopIteration()
        open_arr = np_data[idxs]
        return self.get_data_from_array(open_arr)
    def get_data_from_array(self, open_array):
        """Load the given (file index, slice index) pairs into a batch dict.

        Returns {'data': ndarray, 'fnames': list, 'slice_idxs': list} plus
        'seg' when label_slice is set.
        """
        data = []
        fnames = []
        slice_idxs = []
        labels = []
        # NOTE: "slice" shadows the builtin; each element is a (file, slice) pair.
        for slice in open_array:
            fn_name = self.files[slice[0]]
            # Memory-map so only the requested slice is actually read.
            numpy_array = np.load(fn_name, mmap_mode="r")
            numpy_slice = numpy_array[ :, slice[1], ]
            data.append(numpy_slice[None, self.input_slice[0]])   # 'None' keeps the dimension
            if self.label_slice is not None:
                labels.append(numpy_slice[None, self.label_slice[0]])   # 'None' keeps the dimension
            fnames.append(self.files[slice[0]])
            slice_idxs.append(slice[1])
        ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs}
        if self.label_slice is not None:
            ret_dict['seg'] = np.asarray(labels)
        return ret_dict
class NumpyDataSet(object):
    """
    Couples a NumpyDataLoader with the augmentation transforms and a
    multi-threaded batch producer, exposing the result as an iterable
    dataset.
    """
    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,
                 file_pattern='*.npy', label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None):
        loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, seed=seed, file_pattern=file_pattern,
                                 input_slice=input_slice, label_slice=label_slice, keys=keys)
        self.data_loader = loader
        self.batch_size = batch_size
        self.do_reshuffle = do_reshuffle
        self.number_of_slices = 1
        # Build the augmentation pipeline for this mode/resolution and wrap
        # the loader in the threaded producer.
        self.transforms = get_transforms(mode=mode, target_size=target_size)
        self.augmenter = MultiThreadedDataLoader(loader, self.transforms, num_processes=num_processes,
                                                 num_cached_per_queue=num_cached_per_queue, seeds=seed,
                                                 shuffle=do_reshuffle)
        self.augmenter.restart()

    def __len__(self):
        # Delegate: length equals the number of batches the loader serves.
        return len(self.data_loader)

    def __iter__(self):
        # Optionally reshuffle the underlying slice order, then hand out a
        # freshly renewed augmenter as the actual iterator.
        if self.do_reshuffle:
            self.data_loader.reshuffle()
        self.augmenter.renew()
        return self.augmenter

    def __next__(self):
        return next(self.augmenter)
def reshape(orig_img, append_value=-1024, new_shape=(512, 512, 512)):
    """Embed a 3D image into a float64 array of *new_shape*.

    The original image is copied into the origin corner of the output; every
    remaining voxel is filled with *append_value* (background).

    :param orig_img: 3D array no larger than *new_shape* along any axis
    :param append_value: fill value for the padding region
    :param new_shape: shape of the returned array
    :return: float64 array of shape *new_shape*
    """
    padded = np.full(new_shape, append_value, dtype=np.float64)
    depth, height, width = orig_img.shape
    # Centering offsets were deliberately disabled upstream; the image is
    # anchored at the origin.
    padded[0:depth, 0:height, 0:width] = orig_img
    return padded
def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    """List the regular files directly inside *folder*.

    :param folder: directory to scan (non-recursive)
    :param join: return full paths when True, bare file names otherwise
    :param prefix: keep only names starting with this string, if given
    :param suffix: keep only names ending with this string, if given
    :param sort: sort the resulting list lexicographically
    :return: list of file names or joined paths
    """
    if join:
        make_entry = os.path.join
    else:
        make_entry = lambda _folder, name: name
    matches = []
    for name in os.listdir(folder):
        if not os.path.isfile(os.path.join(folder, name)):
            continue
        if prefix is not None and not name.startswith(prefix):
            continue
        if suffix is not None and not name.endswith(suffix):
            continue
        matches.append(make_entry(folder, name))
    if sort:
        matches.sort()
    return matches
def preprocess_data(root_dir):
    """Normalize, reshape and stack every image/label pair under *root_dir*.

    Expects NIfTI volumes in <root_dir>/imagesTr with matching labels in
    <root_dir>/labelsTr; writes one (2, 64, 64, 64) .npy stack per case into
    <root_dir>/preprocessed.

    :param root_dir: dataset root directory
    :return: True when preprocessing ran, False when the output directory
        already existed (preprocessing is skipped in that case)
    """
    print("preprocess data...")
    image_dir = os.path.join(root_dir, 'imagesTr')
    print(f"image_dir: {image_dir}")
    label_dir = os.path.join(root_dir, 'labelsTr')
    print(f"label_dir: {label_dir}")
    output_dir = os.path.join(root_dir, 'preprocessed')
    print(f"output_dir: {output_dir} ... ", end="")
    classes = 3
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print("created!")
    else:
        # Existing output is taken as "already preprocessed"; never overwrite.
        print("found!\npreprocessed data already available, aborted preprocessing!")
        return False
    print("start preprocessing ... ", end="")
    class_stats = defaultdict(int)
    total = 0
    nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)
    # Strip macOS resource-fork prefixes ("._") from the file names.
    for k in range(0, len(nii_files)):
        if nii_files[k].startswith("._"):
            nii_files[k] = nii_files[k][2:]
    for i, f in enumerate(nii_files):
        image, _ = load(os.path.join(image_dir, f))
        # Labels use the same name without the modality suffix '_0000'.
        label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))
        # Fix: the class-statistics loop previously reused the outer loop
        # variable "i", silently clobbering the file index.
        for c in range(classes):
            class_stats[c] += np.sum(label == c)
            total += np.sum(label == c)
        # Min-max normalize the image to [0, 1] before padding to 64^3.
        image = (image - image.min()) / (image.max() - image.min())
        image = reshape(image, append_value=0, new_shape=(64, 64, 64))
        label = reshape(label, append_value=0, new_shape=(64, 64, 64))
        # Stack image and label along axis 0 and store as one array.
        result = np.stack((image, label))
        np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result)
    print("finished!")
    return True
def create_splits(output_dir, image_dir):
    """Create five random train/val/test splits of the .npy files in
    *image_dir* (50% / 25% / 25%) and pickle them to <output_dir>/splits.pkl.

    Each split is a dict with keys 'train', 'val' and 'test' whose values are
    lists of file stems (file name without the '.npy' extension).
    """
    print("creating splits ... ", end="")
    npy_files = subfiles(image_dir, suffix=".npy", join=False)
    n_files = len(npy_files)
    trainset_size = n_files * 50 // 100
    valset_size = n_files * 25 // 100
    testset_size = n_files * 25 // 100

    def draw(pool, count):
        # Draw *count* random files from *pool* without replacement,
        # returning their stems; mirrors the original choice/remove sequence
        # so the RNG stream is identical.
        chosen = []
        for _ in range(0, count):
            patient = np.random.choice(pool)
            pool.remove(patient)
            chosen.append(patient[:-4])
        return chosen

    splits = []
    for _ in range(0, 5):
        remaining = npy_files.copy()
        split_dict = dict()
        split_dict['train'] = draw(remaining, trainset_size)
        split_dict['val'] = draw(remaining, valset_size)
        split_dict['test'] = draw(remaining, testset_size)
        splits.append(split_dict)
    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
        pickle.dump(splits, f)
    print("finished!")
diff --git a/hyppopy/workflows/unet_usecase/unet_usecase.py b/hyppopy/workflows/unet_usecase/unet_usecase.py
index 550400d..4666f20 100644
--- a/hyppopy/workflows/unet_usecase/unet_usecase.py
+++ b/hyppopy/workflows/unet_usecase/unet_usecase.py
@@ -1,126 +1,33 @@
-# -*- coding: utf-8 -*-
-#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
import os
-import pickle
-
-import torch
-import torch.optim as optim
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-import torch.nn.functional as F
-from networks.RecursiveUNet import UNet
-
-import hyppopy.solverfactory as sfac
-from .unet_uscase_utils import *
-
-
-def unet_usecase(args):
- print("Execute UNet UseCase...")
- data_dir = args.data
- preprocessed_dir = os.path.join(args.data, 'preprocessed')
- solver_plugin = args.plugin
- config_file = args.config
- print(f"input data directory: {data_dir}")
- print(f"use plugin: {solver_plugin}")
- print(f"config file: {config_file}")
-
- factory = sfac.SolverFactory.instance()
- solver = factory.get_solver(solver_plugin)
- solver.read_parameter(config_file)
-
- if preprocess_data(data_dir):
- create_splits(output_dir=data_dir, image_dir=preprocessed_dir)
-
- with open(os.path.join(data_dir, "splits.pkl"), 'rb') as f:
- splits = pickle.load(f)
-
- tr_keys = splits[solver.settings.fold]['train']
- val_keys = splits[solver.settings.fold]['val']
- test_keys = splits[solver.settings.fold]['test']
-
- def loss_function(patch_size, batch_size):
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- train_data_loader = NumpyDataSet(data_dir,
- target_size=patch_size,
- batch_size=batch_size,
- keys=tr_keys)
- val_data_loader = NumpyDataSet(data_dir,
- target_size=patch_size,
- batch_size=batch_size,
- mode="val",
- do_reshuffle=False)
- model = UNet(num_classes=solver.settings.num_classes, in_channels=solver.settings.in_channels)
- model.to(device)
-
- dice_loss = SoftDiceLoss(batch_dice=True)
- ce_loss = torch.nn.CrossEntropyLoss()
- node_optimizer = optim.Adam(model.parameters(), lr=solver.settings.learning_rate)
- scheduler = ReduceLROnPlateau(node_optimizer, 'min')
-
- model.train()
-
- data = None
- batch_counter = 0
- for data_batch in train_data_loader:
-
- node_optimizer.zero_grad()
-
- data = data_batch['data'][0].float().to(device)
- target = data_batch['seg'][0].long().to(device)
-
- pred = model(data)
- pred_softmax = F.softmax(pred, dim=1)
-
- loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze())
- loss.backward()
- node_optimizer.step()
- batch_counter += 1
-
- assert data is not None, 'data is None. Please check if your dataloader works properly'
-
- model.eval()
-
- data = None
- loss_list = []
-
- with torch.no_grad():
- for data_batch in val_data_loader:
- data = data_batch['data'][0].float().to(device)
- target = data_batch['seg'][0].long().to(device)
-
- pred = model(data)
- pred_softmax = F.softmax(pred)
-
- loss = dice_loss(pred_softmax, target.squeeze()) + ce_loss(pred, target.squeeze())
- loss_list.append(loss.item())
-
- assert data is not None, 'data is None. Please check if your dataloader works properly'
- scheduler.step(np.mean(loss_list))
-
- data = []
-
- # solver.set_data(data)
- # solver.read_parameter(config_file)
- # solver.set_loss_function(loss_function)
- # solver.run()
- # solver.get_results()
-
+import numpy as np
+import pandas as pd
+from sklearn.svm import SVC
+from sklearn.model_selection import cross_val_score
+from hyppopy.projectmanager import ProjectManager
+from hyppopy.workflows.workflowbase import WorkflowBase
+from hyppopy.workflows.dataloader.unetloader import UnetDataLoaderBase
+class unet_usecase(WorkflowBase):
+ def setup(self):
+ pass
+ def blackbox_function(self, data, params):
+ pass
diff --git a/hyppopy/workflows/workflowbase.py b/hyppopy/workflows/workflowbase.py
index fea491b..61c879b 100644
--- a/hyppopy/workflows/workflowbase.py
+++ b/hyppopy/workflows/workflowbase.py
@@ -1,76 +1,61 @@
# -*- coding: utf-8 -*-
#
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)
-from hyppopy.solverfactory import SolverFactory
from hyppopy.deepdict import DeepDict
+from hyppopy.solverfactory import SolverFactory
+from hyppopy.projectmanager import ProjectManager
from hyppopy.globals import SETTINGSCUSTOMPATH, SETTINGSSOLVERPATH
import os
import abc
import logging
from hyppopy.globals import DEBUGLEVEL
LOG = logging.getLogger(os.path.basename(__file__))
LOG.setLevel(DEBUGLEVEL)
-class Workflow(object):
- _solver = None
- _args = None
+class WorkflowBase(object):
- def __init__(self, args):
- self._args = args
- if args.plugin is None or args.plugin == '':
- dd = DeepDict(args.config)
- ppath = "use_plugin"
- if not dd.has_section(ppath):
- msg = f"invalid config file, missing section {ppath}"
- LOG.error(msg)
- raise LookupError(msg)
- plugin = dd[SETTINGSSOLVERPATH+'/'+ppath]
- else:
- plugin = args.plugin
- self._solver = SolverFactory.get_solver(plugin)
- self.solver.read_parameter(args.config)
+ def __init__(self):
+ self._solver = SolverFactory.get_solver(ProjectManager.use_plugin)
+ self.solver.set_hyperparameters(ProjectManager.config['hyperparameter'])
def run(self):
self.setup()
self.solver.set_loss_function(self.blackbox_function)
self.solver.run()
self.test()
def get_results(self):
return self.solver.get_results()
@abc.abstractmethod
def setup(self):
raise NotImplementedError('the user has to implement this function')
@abc.abstractmethod
def blackbox_function(self):
raise NotImplementedError('the user has to implement this function')
@abc.abstractmethod
def test(self):
pass
@property
def solver(self):
return self._solver
- @property
- def args(self):
- return self._args