diff --git a/examples/tutorial_hyppopyprojectclass.py b/examples/tutorial_hyppopyprojectclass.py new file mode 100644 index 0000000..846a1e5 --- /dev/null +++ b/examples/tutorial_hyppopyprojectclass.py @@ -0,0 +1,56 @@ +# In this tutorial we demonstrate the HyppopyProject class usage + +# import the HyppopyProject class +from hyppopy.HyppopyProject import HyppopyProject + +# To configure a solver we need to instanciate a HyppopyProject class. +# This class can be configured using a nested dict. This dict has two +# obligatory sections, hyperparameter and settings. A hyperparameter +# is described using a dict containing a section, data and type field +# and thus the hyperparameter section is a collection of hyperparameter +# dicts. The settings section keeps solver settings. These might depend +# on the solver used and need to be checked for each. E.g. Randomsearch, +# Hyperopt and Optunity need a solver setting max_iterations, the Grid- +# searchSolver don't. +config = { +"hyperparameter": { + "C": { + "domain": "uniform", + "data": [0.0001, 20], + "type": "float" + }, + "gamma": { + "domain": "uniform", + "data": [0.0001, 20.0], + "type": "float" + }, + "kernel": { + "domain": "categorical", + "data": ["linear", "sigmoid", "poly", "rbf"], + "type": "str" + } +}, +"settings": { + "solver": { + "max_iterations": 500 + }, + "custom": {} +}} + +# When creating a HyppopyProject instance we +# pass the config dictionary to the constructor. +project = HyppopyProject(config=config) + +# When building the project programmatically we can also use the methods +# add_hyperparameter and add_settings +project.clear() +project.add_hyperparameter(name="C", domain="uniform", data=[0.0001, 20], dtype="float") +project.add_hyperparameter(name="kernel", domain="categorical", data=["linear", "sigmoid"], dtype="str") +project.add_settings(section="solver", name="max_iterations", value=500) + +# The custom section can be used freely +project.add_settings(section="custom", name="my_var", value=10) + +# Settings are automatically transformed to member variables of the project class with the section as prefix +if project.solver_max_iterations < 1000 and project.custom_my_var == 10: + print("Project configured!") diff --git a/hyppopy/HyppopyProject.py b/hyppopy/HyppopyProject.py index 6ca2295..fdb0a86 100644 --- a/hyppopy/HyppopyProject.py +++ b/hyppopy/HyppopyProject.py @@ -1,74 +1,96 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) from hyppopy.globals import * LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class HyppopyProject(object): def __init__(self, config=None): - self._hyperparameter = None - self._settings = None + self._hyperparameter = {} + self._settings = {} self._extmembers = [] if config is not None: self.set_config(config) def clear(self): - self._hyperparameter = None - self._settings = None + self._hyperparameter = {} + self._settings = {} for added in self._extmembers: if added in self.__dict__.keys(): del self.__dict__[added] self._extmembers = [] def set_config(self, config): self.clear() assert isinstance(config, dict), "Input Error, config of type {} not supported!".format(type(config)) assert HYPERPARAMETERPATH in config.keys(), "Missing hyperparameter section in config dict" assert SETTINGSPATH in config.keys(), "Missing settings section in config dict" self._hyperparameter = config[HYPERPARAMETERPATH] self._settings = config[SETTINGSPATH] self.parse_members() + def add_hyperparameter(self, name, domain, data, dtype): + assert isinstance(name, str), "precondition violation, name of type {} not allowed, expect str!".format(type(name)) + assert isinstance(domain, str), "precondition violation, domain of type {} not allowed, expect str!".format(type(domain)) + assert domain in SUPPORTED_DOMAINS, "domain {} not supported, expect {}!".format(domain, SUPPORTED_DOMAINS) + assert isinstance(data, list) or isinstance(data, tuple), "precondition violation, data of type {} not allowed, expect list or tuple!".format(type(data)) + if domain != "categorical": + assert len(data) == 3 or len(data) == 2, "precondition violation, data must be a list of len 2 or 3" + assert isinstance(dtype, str), "precondition violation, dtype of type {} not allowed, expect str!".format(type(dtype)) + assert dtype in SUPPORTED_DTYPES, "precondition violation, dtype {} not supported, expect {}!".format(dtype, SUPPORTED_DTYPES) + self._hyperparameter[name] = {"domain": domain, "data": data, "type": dtype} + + def add_settings(self, section, name, value): + assert isinstance(section, str), "precondition violation, section of type {} not allowed, expect str!".format(type(section)) + assert isinstance(name, str), "precondition violation, name of type {} not allowed, expect str!".format(type(name)) + if section not in self._settings.keys(): + self._settings[section] = {} + self._settings[section][name] = value + self.parse_members() + def parse_members(self): for section_name, content in self.settings.items(): for name, value in content.items(): member_name = section_name + "_" + name - setattr(self, member_name, value) - self._extmembers.append(member_name) + if member_name not in self._extmembers: + setattr(self, member_name, value) + self._extmembers.append(member_name) + else: + self.__dict__[member_name] = value def get_typeof(self, hyperparametername): if not hyperparametername in self.hyperparameter.keys(): return None dtype = self.hyperparameter[hyperparametername]["type"] if dtype == 'str': return str if dtype == 'int': return int if dtype == 'float': return float @property def hyperparameter(self): return self._hyperparameter @property def settings(self): return self._settings diff --git a/hyppopy/globals.py b/hyppopy/globals.py index d7814d0..3875fe7 100644 --- a/hyppopy/globals.py +++ b/hyppopy/globals.py @@ -1,31 +1,34 @@ # DKFZ # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. import os import sys import logging ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT) LIBNAME = "hyppopy" TESTDATA_DIR = os.path.join(ROOT, *(LIBNAME, "tests", "data")) HYPERPARAMETERPATH = "hyperparameter" SETTINGSPATH = "settings" VFUNCDATAPATH = os.path.join(os.path.join(ROOT, LIBNAME), "virtualparameterspace") +SUPPORTED_DOMAINS = ["uniform", "normal", "loguniform", "categorical"] +SUPPORTED_DTYPES = ["int", "float", "str"] + DEFAULTITERATIONS = 500 LOGFILENAME = os.path.join(ROOT, '{}_log.log'.format(LIBNAME)) DEBUGLEVEL = logging.DEBUG logging.basicConfig(filename=LOGFILENAME, filemode='w', format='%(levelname)s: %(name)s - %(message)s') diff --git a/hyppopy/tests/test_hyppopyproject.py b/hyppopy/tests/test_hyppopyproject.py index c48473f..049cf6c 100644 --- a/hyppopy/tests/test_hyppopyproject.py +++ b/hyppopy/tests/test_hyppopyproject.py @@ -1,73 +1,97 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical and Biological Informatics. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE.txt or http://www.mitk.org for details. # # Author: Sven Wanner (s.wanner@dkfz.de) import os import unittest import numpy as np from hyppopy.HyppopyProject import HyppopyProject from hyppopy.globals import TESTDATA_DIR def foo(a, b): return a + b class VirtualFunctionTestSuite(unittest.TestCase): def setUp(self): pass def test_project_creation(self): config = { "hyperparameter": { "C": { "domain": "uniform", "data": [0.0001, 20], "type": "float" }, "kernel": { "domain": "categorical", "data": ["linear", "sigmoid", "poly", "rbf"], "type": "str" } }, "settings": { "solver": { "max_iterations": 300 }, "custom": { "param1": 1, "param2": 2, "function": foo } }} project = HyppopyProject() project.set_config(config) self.assertEqual(project.hyperparameter["C"]["domain"], "uniform") self.assertEqual(project.hyperparameter["C"]["data"], [0.0001, 20]) self.assertEqual(project.hyperparameter["C"]["type"], "float") self.assertEqual(project.hyperparameter["kernel"]["domain"], "categorical") self.assertEqual(project.hyperparameter["kernel"]["data"], ["linear", "sigmoid", "poly", "rbf"]) self.assertEqual(project.hyperparameter["kernel"]["type"], "str") self.assertEqual(project.solver_max_iterations, 300) self.assertEqual(project.custom_param1, 1) self.assertEqual(project.custom_param2, 2) self.assertEqual(project.custom_function(2, 3), 5) self.assertTrue(project.get_typeof("C") is float) self.assertTrue(project.get_typeof("kernel") is str) + + project.clear() + self.assertTrue(len(project.hyperparameter) == 0) + self.assertTrue(len(project.settings) == 0) + self.assertTrue("solver_max_iterations" not in project.__dict__.keys()) + self.assertTrue("custom_param1" not in project.__dict__.keys()) + self.assertTrue("custom_param2" not in project.__dict__.keys()) + self.assertTrue("custom_function" not in project.__dict__.keys()) + + project.add_hyperparameter(name="C", domain="uniform", data=[0.0001, 20], dtype="float") + project.add_hyperparameter(name="kernel", domain="categorical", data=["linear", "sigmoid", "poly", "rbf"], dtype="str") + + self.assertEqual(project.hyperparameter["C"]["domain"], "uniform") + self.assertEqual(project.hyperparameter["C"]["data"], [0.0001, 20]) + self.assertEqual(project.hyperparameter["C"]["type"], "float") + self.assertEqual(project.hyperparameter["kernel"]["domain"], "categorical") + self.assertEqual(project.hyperparameter["kernel"]["data"], ["linear", "sigmoid", "poly", "rbf"]) + self.assertEqual(project.hyperparameter["kernel"]["type"], "str") + + project.add_settings("solver", "max_iterations", 500) + self.assertEqual(project.solver_max_iterations, 500) + project.add_settings("solver", "max_iterations", 200) + self.assertEqual(project.solver_max_iterations, 200) +