diff --git a/examples/tutorial_gridsearch.py b/examples/tutorial_gridsearch.py index cbead25..56093cb 100644 --- a/examples/tutorial_gridsearch.py +++ b/examples/tutorial_gridsearch.py @@ -1,129 +1,129 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical Image Computing. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE # In this tutorial we solve an optimization problem using the GridsearchSolver. # Gridsearch is very inefficient; most of the time a Randomsearch is the # better choice. # import the HyppopyProject class keeping track of inputs from hyppopy.HyppopyProject import HyppopyProject # import the GridsearchSolver class -from hyppopy.solver.GridsearchSolver import GridsearchSolver +from hyppopy.solvers.GridsearchSolver import GridsearchSolver # import the BlackboxFunction class wrapping your problem for Hyppopy from hyppopy.BlackboxFunction import BlackboxFunction # To configure the GridsearchSolver we only need the hyperparameter section. Another # difference to the other solvers is that we need to define a grid sampling in addition # to the range: 'data': [0, 1, 100] means sampling the space from 0 to 1 in 100 # intervals. Gridsearch also supports categorical, uniform, normal and loguniform sampling config = { "hyperparameter": { "C": { "domain": "uniform", "data": [0.0001, 20, 20], "type": "float" }, "gamma": { "domain": "uniform", "data": [0.0001, 20.0, 20], "type": "float" }, "kernel": { "domain": "categorical", "data": ["linear", "sigmoid", "poly", "rbf"], "type": "str" } }, "settings": { "solver": {}, "custom": {} }} # When creating a HyppopyProject instance we # pass the config dictionary to the constructor. project = HyppopyProject(config=config) # Hyppopy offers a class called BlackboxFunction to wrap your problem for Hyppopy. # The function signature is as follows: # BlackboxFunction(blackbox_func=None, # dataloader_func=None, # preprocess_func=None, # callback_func=None, # data=None, # **kwargs) # # This means we can set a couple of function pointers, a data object and an arbitrary number of custom parameters via kwargs. # # - blackbox_func: a function pointer to the actual, user defined, blackbox function that is computing our loss # - dataloader_func: a function pointer to a function handling the data loading # - preprocess_func: a function pointer to a function automatically executed before starting the optimization process # - callback_func: a function pointer to a function that is called after each iteration with the trial object as input # - data: setting data can be done via dataloader_func or directly # - kwargs are passed to all functions above and thus can be used for parameter sharing between the functions # # (see the documentation for more details) # # Below we demonstrate the usage of the above by defining a my_dataloader_function, which in fact only grabs the # iris dataset from sklearn and returns it, and a my_callback_function, which gets as input a dictionary containing the # state of the iteration and thus can be used to access the current state of each solver iteration (a preprocess_func is # not used in this example). Finally we define the actual loss function # my_loss_function, which gets as input a data object and params.
Both parameters are fixed: the first is determined by # what the dataloader returns (or the data object set in the constructor), the second is a dictionary # with a sample of your hyperparameter space whose content is chosen by the solver. from sklearn.svm import SVC from sklearn.datasets import load_iris from sklearn.model_selection import cross_val_score def my_dataloader_function(**kwargs): print("Dataloading...") iris_data = load_iris() return [iris_data.data, iris_data.target] def my_callback_function(**kwargs): print("\r{}".format(kwargs), end="") def my_loss_function(data, params): clf = SVC(**params) return -cross_val_score(estimator=clf, X=data[0], y=data[1], cv=3).mean() # We now create the BlackboxFunction object and pass the function pointers defined above. blackbox = BlackboxFunction(blackbox_func=my_loss_function, dataloader_func=my_dataloader_function, callback_func=my_callback_function) # create a solver instance solver = GridsearchSolver(project) # pass the loss function to the solver solver.blackbox = blackbox # run the solver solver.run() # get the results via get_results() which returns a pandas DataFrame # containing the complete history and a dict best containing the # best parameter set. df, best = solver.get_results() print("\n") print("*"*100) print("Best Parameter Set:\n{}".format(best)) print("*"*100) diff --git a/hyppopy/HyppopyProject.py b/hyppopy/HyppopyProject.py index ad9dc20..d9eaab2 100644 --- a/hyppopy/HyppopyProject.py +++ b/hyppopy/HyppopyProject.py @@ -1,116 +1,117 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical Image Computing. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE.
# # See LICENSE import warnings from hyppopy.globals import * LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class HyppopyProject(object): def __init__(self, config=None): self._hyperparameter = {} self._settings = {} self._extmembers = [] if config is not None: self.set_config(config) def clear(self): self._hyperparameter = {} self._settings = {} for added in self._extmembers: if added in self.__dict__.keys(): del self.__dict__[added] self._extmembers = [] def set_config(self, config): self.clear() assert isinstance(config, dict), "Input Error, config of type {} not supported!".format(type(config)) assert HYPERPARAMETERPATH in config.keys(), "Missing hyperparameter section in config dict" #assert SETTINGSPATH in config.keys(), "Missing settings section in config dict" # if not SETTINGSPATH in config.keys(): # config[SETTINGSPATH] = {"solver": {"max_iterations": DEFAULTITERATIONS}} # msg = "config dict had no section {0}/solver/max_iterations, set default value: {1}".format(SETTINGSPATH, DEFAULTITERATIONS) # warnings.warn(msg) # LOG.warning(msg) # elif not "max_iterations" in config[SETTINGSPATH].keys(): # config[SETTINGSPATH]["solver"] = {"max_iterations": DEFAULTITERATIONS} # msg = "config dict had no section {0}/solver/max_iterations, set default value: {1}".format(SETTINGSPATH, DEFAULTITERATIONS) # warnings.warn(msg) # LOG.warning(msg) self._hyperparameter = config[HYPERPARAMETERPATH] - self._settings = config[SETTINGSPATH] + if SETTINGSPATH in config.keys(): + self._settings = config[SETTINGSPATH] self.parse_members() def add_hyperparameter(self, **kwargs): assert 'name' in kwargs.keys(), "precondition violation, obligatory parameter name not found!" assert 'domain' in kwargs.keys(), "precondition violation, obligatory parameter domain not found!" assert 'data' in kwargs.keys(), "precondition violation, obligatory parameter data not found!" assert 'dtype' in kwargs.keys(), "precondition violation, obligatory parameter dtype not found!" 
name = kwargs['name']; del kwargs['name'] domain = kwargs['domain']; del kwargs['domain'] data = kwargs['data']; del kwargs['data'] dtype = kwargs['dtype']; del kwargs['dtype'] assert isinstance(name, str), "precondition violation, name of type {} not allowed, expect str!".format(type(name)) assert isinstance(domain, str), "precondition violation, domain of type {} not allowed, expect str!".format(type(domain)) assert domain in SUPPORTED_DOMAINS, "domain {} not supported, expect {}!".format(domain, SUPPORTED_DOMAINS) assert isinstance(data, list) or isinstance(data, tuple), "precondition violation, data of type {} not allowed, expect list or tuple!".format(type(data)) if domain != "categorical": assert len(data) == 3 or len(data) == 2, "precondition violation, data must be a list of len 2 or 3" assert isinstance(dtype, str), "precondition violation, dtype of type {} not allowed, expect str!".format(type(dtype)) assert dtype in SUPPORTED_DTYPES, "precondition violation, dtype {} not supported, expect {}!".format(dtype, SUPPORTED_DTYPES) self._hyperparameter[name] = {"domain": domain, "data": data, "type": dtype} for key, value in kwargs.items(): self._hyperparameter[name][key] = value def add_settings(self, section, name, value): assert isinstance(section, str), "precondition violation, section of type {} not allowed, expect str!".format(type(section)) assert isinstance(name, str), "precondition violation, name of type {} not allowed, expect str!".format(type(name)) if section not in self._settings.keys(): self._settings[section] = {} self._settings[section][name] = value self.parse_members() def parse_members(self): for section_name, content in self.settings.items(): for name, value in content.items(): member_name = section_name + "_" + name if member_name not in self._extmembers: setattr(self, member_name, value) self._extmembers.append(member_name) else: self.__dict__[member_name] = value def get_typeof(self, hyperparametername): if not hyperparametername in self.hyperparameter.keys(): return None dtype = self.hyperparameter[hyperparametername]["type"] if dtype == 'str': return str if dtype == 'int': return int if dtype == 'float': return float @property def hyperparameter(self): return self._hyperparameter @property def settings(self): return self._settings diff --git a/hyppopy/globals.py b/hyppopy/globals.py index 313456d..ff68aba 100644 --- a/hyppopy/globals.py +++ b/hyppopy/globals.py @@ -1,36 +1,36 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical Image Computing. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. 
# # See LICENSE import os import sys import logging ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT) LIBNAME = "hyppopy" TESTDATA_DIR = os.path.join(ROOT, *(LIBNAME, "tests", "data")) HYPERPARAMETERPATH = "hyperparameter" SETTINGSPATH = "settings" VFUNCDATAPATH = os.path.join(os.path.join(ROOT, LIBNAME), "virtualparameterspace") SUPPORTED_DOMAINS = ["uniform", "normal", "loguniform", "categorical"] SUPPORTED_DTYPES = ["int", "float", "str"] -DEFAULTITERATIONS = 500 +#DEFAULTITERATIONS = 500 DEFAULTGRIDFREQUENCY = 10 LOGFILENAME = os.path.join(ROOT, '{}_log.log'.format(LIBNAME)) DEBUGLEVEL = logging.DEBUG logging.basicConfig(filename=LOGFILENAME, filemode='w', format='%(levelname)s: %(name)s - %(message)s') diff --git a/hyppopy/solvers/HyppopySolver.py b/hyppopy/solvers/HyppopySolver.py index 957ddaf..03a6fce 100644 --- a/hyppopy/solvers/HyppopySolver.py +++ b/hyppopy/solvers/HyppopySolver.py @@ -1,341 +1,338 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical Image Computing. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. # # See LICENSE import abc import os import copy import types import logging import datetime import numpy as np import pandas as pd from hyperopt import Trials from hyppopy.globals import DEBUGLEVEL from hyppopy.VisdomViewer import VisdomViewer from hyppopy.HyppopyProject import HyppopyProject from hyppopy.BlackboxFunction import BlackboxFunction from hyppopy.VirtualFunction import VirtualFunction -from hyppopy.globals import DEBUGLEVEL, DEFAULTITERATIONS +from hyppopy.globals import DEBUGLEVEL LOG = logging.getLogger(os.path.basename(__file__)) LOG.setLevel(DEBUGLEVEL) class HyppopySolver(object): """ The HyppopySolver class is the base class for all solver addons. It defines virtual functions a child class has to implement to deal with the front-end communication, orchestrating the optimization process and ensuring proper storage of the process information. The key idea is that the HyppopySolver class defines an interface to configure and run an object instance of itself independently from the concrete solver lib used to optimize in the background. To achieve this goal an addon developer needs to implement the abstract methods 'convert_searchspace', 'execute_solver' and 'loss_function_call'. These methods abstract the peculiarities of the solver libs to offer, on the user side, a simple and consistent parameter space configuration and optimization procedure. The method 'convert_searchspace' transforms the hyppopy parameter space description into the solver lib specific description. The method loss_function_call is used to handle solver lib specifics of calling the actual blackbox function and execute_solver is executed when the run method is invoked and takes care of calling the solver lib's solving routine.
""" def __init__(self, project=None): - self._idx = None - self._best = None - self._trials = None - self._blackbox = None - self._max_iterations = None - self._project = project - self._total_duration = None - self._solver_overhead = None - self._time_per_iteration = None - self._accumulated_blackbox_time = None - self._has_maxiteration_field = True - self._visdom_viewer = None + self._idx = None # current iteration counter + self._best = None # best parameter set + self._trials = None # trials object, hyppopy uses the Trials object from hyperopt + self._blackbox = None # blackbox function, eiter a function or a BlackboxFunction instance + self._max_iterations = None # number of iteration the solver is doing at max + self._project = project # HyppopyProject instance + self._total_duration = None # keeps track of the solvers running time + self._solver_overhead = None # stores the time overhead of the solver, means total time minus time in blackbox + self._time_per_iteration = None # mean time per iterration + self._accumulated_blackbox_time = None # summed time the solver was in the blackbox function + self._has_maxiteration_field = True # this variable has to be set to False if the solver doesn't make use of max_iterations + self._visdom_viewer = None # visdom viewer instance @abc.abstractmethod def convert_searchspace(self, hyperparameter): """ This function gets the unified hyppopy-like parameterspace description as input and, if necessary, should convert it into a solver lib specific format. The function is invoked when run is called and what it returns is passed as searchspace argument to the function execute_solver. :param hyperparameter: [dict] nested parameter description dict e.g. {'name': {'domain':'uniform', 'data':[0,1], 'type':'float'}, ...} :return: [object] converted hyperparameter space """ raise NotImplementedError('users must define convert_searchspace to use this class') @abc.abstractmethod def execute_solver(self, searchspace): """ This function is called immediatly after convert_searchspace and get the output of the latter as input. It's purpose is to call the solver libs main optimization function. :param searchspace: converted hyperparameter space """ raise NotImplementedError('users must define execute_solver to use this class') @abc.abstractmethod def loss_function_call(self, params): """ This function is called within the function loss_function and encapsulates the actual blackbox function call in each iteration. The function loss_function takes care of the iteration driving and reporting, but each solver lib might need some special treatment between the parameter set selection and the calling of the actual blackbox function, e.g. parameter converting. :param params: [dict] hyperparameter space sample e.g. {'p1': 0.123, 'p2': 3.87, ...} :return: [float] loss """ raise NotImplementedError('users must define convert_searchspace to use this class') def loss_function(self, **params): """ This function is called each iteration with a selected parameter set. The parameter set selection is driven by the solver lib itself. The purpose of this function is to take care of the iteration reporting and the calling of the callback_func is available. As a developer you might want to overwrite this function completely (e.g. HyperoptSolver) but then you need to take care for iteration reporting for yourself. The alternative is to only implement loss_function_call (e.g. OptunitySolver). :param params: [dict] hyperparameter space sample e.g. 
{'p1': 0.123, 'p2': 3.87, ...} :return: [float] loss """ self._idx += 1 vals = {} idx = {} for key, value in params.items(): vals[key] = [value] idx[key] = [self._idx] trial = {'tid': self._idx, 'result': {'loss': None, 'status': 'ok'}, 'misc': { 'tid': self._idx, 'idxs': idx, 'vals': vals }, 'book_time': datetime.datetime.now(), 'refresh_time': None } try: loss = self.loss_function_call(params) trial['result']['loss'] = loss trial['result']['status'] = 'ok' if np.isnan(loss): trial['result']['status'] = 'failed' except Exception as e: LOG.error("computing loss failed due to:\n {}".format(e)) loss = np.nan trial['result']['loss'] = np.nan trial['result']['status'] = 'failed' trial['refresh_time'] = datetime.datetime.now() self._trials.trials.append(trial) cbd = copy.deepcopy(params) cbd['iterations'] = self._idx cbd['loss'] = loss cbd['status'] = trial['result']['status'] cbd['book_time'] = trial['book_time'] cbd['refresh_time'] = trial['refresh_time'] if isinstance(self.blackbox, BlackboxFunction) and self.blackbox.callback_func is not None: self.blackbox.callback_func(**cbd) if self._visdom_viewer is not None: self._visdom_viewer.update(cbd) return loss def run(self, print_stats=True): """ This function starts the optimization process. :param print_stats: [bool] enable or disable console output """ self._idx = 0 self.trials = Trials() if self._has_maxiteration_field: if 'solver_max_iterations' not in self.project.__dict__: - msg = "Missing max_iteration entry in project, use default {}!".format(DEFAULTITERATIONS) - LOG.warning(msg) - print("WARNING: {}".format(msg)) - setattr(self.project, 'solver_max_iterations', DEFAULTITERATIONS) + raise Exception("Missing max_iterations parameter which is essential for this type of solver!") self._max_iterations = self.project.solver_max_iterations start_time = datetime.datetime.now() try: search_space = self.convert_searchspace(self.project.hyperparameter) except Exception as e: msg = "Failed to convert searchspace, error: {}".format(e) LOG.error(msg) raise AssertionError(msg) try: self.execute_solver(search_space) except Exception as e: msg = "Failed to execute solver, error: {}".format(e) LOG.error(msg) raise AssertionError(msg) end_time = datetime.datetime.now() dt = end_time - start_time days = divmod(dt.total_seconds(), 86400) hours = divmod(days[1], 3600) minutes = divmod(hours[1], 60) seconds = divmod(minutes[1], 1) milliseconds = divmod(seconds[1], 0.001) self._total_duration = [int(days[0]), int(hours[0]), int(minutes[0]), int(seconds[0]), int(milliseconds[0])] if print_stats: self.print_best() self.print_timestats() def get_results(self): """ This function returns the complete optimization history as a pandas DataFrame and a dict with the optimal parameter set. :return: [DataFrame], [dict] history and optimal parameter set """ assert isinstance(self.trials, Trials), "precondition violation, wrong trials type! Maybe solver was not yet executed?"
results = {'duration': [], 'losses': [], 'status': []} pset = self.trials.trials[0]['misc']['vals'] for p in pset.keys(): results[p] = [] for n, trial in enumerate(self.trials.trials): t1 = trial['book_time'] t2 = trial['refresh_time'] results['duration'].append((t2 - t1).total_seconds() * 1000.0) results['losses'].append(trial['result']['loss']) results['status'].append(trial['result']['status'] == 'ok') losses = np.array(results['losses']) results['losses'] = list(losses) pset = trial['misc']['vals'] for p in pset.items(): results[p[0]].append(p[1][0]) return pd.DataFrame.from_dict(results), self.best def print_best(self): print("\n") print("#" * 40) print("### Best Parameter Choice ###") print("#" * 40) for name, value in self.best.items(): print(" - {}\t:\t{}".format(name, value)) print("\n - number of iterations\t:\t{}".format(self.trials.trials[-1]['tid']+1)) print(" - total time\t:\t{}d:{}h:{}m:{}s:{}ms".format(self._total_duration[0], self._total_duration[1], self._total_duration[2], self._total_duration[3], self._total_duration[4])) print("#" * 40) def compute_time_statistics(self): dts = [] for trial in self._trials.trials: if 'book_time' in trial.keys() and 'refresh_time' in trial.keys(): dt = trial['refresh_time'] - trial['book_time'] dts.append(dt.total_seconds()) self._time_per_iteration = np.mean(dts) * 1e3 self._accumulated_blackbox_time = np.sum(dts) * 1e3 tmp = self.total_duration - self._accumulated_blackbox_time self._solver_overhead = int(np.round(100.0 / (self.total_duration+1e-12) * tmp)) def print_timestats(self): print("\n") print("#" * 40) print("### Timing Statistics ###") print("#" * 40) print(" - per iteration: {}ms".format(int(self.time_per_iteration*1e4)/10000)) print(" - total time: {}d:{}h:{}m:{}s:{}ms".format(self._total_duration[0], self._total_duration[1], self._total_duration[2], self._total_duration[3], self._total_duration[4])) print("#" * 40) print(" - solver overhead: {}%".format(self.solver_overhead)) def start_viewer(self, port=8097, server="http://localhost"): try: self._visdom_viewer = VisdomViewer(self._project, port, server) except Exception as e: import warnings warnings.warn("Failed starting VisdomViewer. Is the server running? 
If not start it via $visdom") LOG.error("Failed starting VisdomViewer: {}".format(e)) self._visdom_viewer = None @property def project(self): return self._project @project.setter def project(self, value): if not isinstance(value, HyppopyProject): msg = "Input error, project_manager of type: {} not allowed!".format(type(value)) LOG.error(msg) raise IOError(msg) self._project = value @property def blackbox(self): return self._blackbox @blackbox.setter def blackbox(self, value): if isinstance(value, types.FunctionType) or isinstance(value, BlackboxFunction) or isinstance(value, VirtualFunction): self._blackbox = value else: self._blackbox = None msg = "Input error, blackbox of type: {} not allowed!".format(type(value)) LOG.error(msg) raise IOError(msg) @property def best(self): return self._best @best.setter def best(self, value): if not isinstance(value, dict): msg = "Input error, best of type: {} not allowed!".format(type(value)) LOG.error(msg) raise IOError(msg) self._best = value @property def trials(self): return self._trials @trials.setter def trials(self, value): self._trials = value @property def max_iterations(self): return self._max_iterations @max_iterations.setter def max_iterations(self, value): if not isinstance(value, int): msg = "Input error, max_iterations of type: {} not allowed!".format(type(value)) LOG.error(msg) raise IOError(msg) if value < 1: msg = "Precondition violation, max_iterations < 1!" LOG.error(msg) raise IOError(msg) self._max_iterations = value @property def total_duration(self): return (self._total_duration[0]*86400 + self._total_duration[1] * 3600 + self._total_duration[2] * 60 + self._total_duration[3]) * 1000 + self._total_duration[4] @property def solver_overhead(self): if self._solver_overhead is None: self.compute_time_statistics() return self._solver_overhead @property def time_per_iteration(self): if self._time_per_iteration is None: self.compute_time_statistics() return self._time_per_iteration @property def accumulated_blackbox_time(self): if self._accumulated_blackbox_time is None: self.compute_time_statistics() - return self._accumulated_blackbox_time + return self._accumulated_blackbox_time \ No newline at end of file diff --git a/hyppopy/tests/test_quasirandomsearchsolver.py b/hyppopy/tests/test_quasirandomsearchsolver.py index 54b2c26..76b2c47 100644 --- a/hyppopy/tests/test_quasirandomsearchsolver.py +++ b/hyppopy/tests/test_quasirandomsearchsolver.py @@ -1,160 +1,160 @@ # DKFZ # # # Copyright (c) German Cancer Research Center, # Division of Medical Image Computing. # All rights reserved. # # This software is distributed WITHOUT ANY WARRANTY; without # even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. 
# # See LICENSE import unittest import matplotlib.pylab as plt from hyppopy.solvers.QuasiRandomsearchSolver import * from hyppopy.VirtualFunction import VirtualFunction from hyppopy.HyppopyProject import HyppopyProject class QuasiRandomsearchTestSuite(unittest.TestCase): def setUp(self): pass def test_get_gaussian_ranges(self): interval = [0, 10] N = 10 ranges = get_gaussian_ranges(interval[0], interval[1], N) - gt = [[0, 1.97368411013644], - [1.97368411013644, 3.1010703630207566], - [3.1010703630207566, 3.856779967954119], - [3.856779967954119, 4.4512421980703], - [4.4512421980703, 5.000000000000001], - [5.000000000000001, 5.5487578019297015], - [5.5487578019297015, 6.143220032045882], - [6.143220032045882, 6.898929636979244], - [6.898929636979244, 8.026315889863561], - [8.026315889863561, 10.0]] + gt = [[0, 2.592443381276233], + [2.592443381276233, 3.673134565097225], + [3.673134565097225, 4.251586871937128], + [4.251586871937128, 4.6491509407201], + [4.6491509407201, 5.000000000000001], + [5.000000000000001, 5.350849059279902], + [5.350849059279902, 5.748413128062873], + [5.748413128062873, 6.326865434902777], + [6.326865434902777, 7.407556618723769], + [7.407556618723769, 10.000000000000002]] for a, b in zip(ranges, gt): self.assertAlmostEqual(a[0], b[0]) self.assertAlmostEqual(a[1], b[1]) interval = [-100, 100] N = 10 ranges = get_gaussian_ranges(interval[0], interval[1], N) - gt = [[-100, -60.526317797271204], - [-60.526317797271204, -37.97859273958487], - [-37.97859273958487, -22.864400640917623], - [-22.864400640917623, -10.975156038594006], - [-10.975156038594006, 0.0], - [0.0, 10.975156038594006], - [10.975156038594006, 22.864400640917623], - [22.864400640917623, 37.97859273958487], - [37.97859273958487, 60.526317797271204], - [60.526317797271204, 100.0]] + gt = [[-100, -48.151132374475345], + [-48.151132374475345, -26.537308698055508], + [-26.537308698055508, -14.96826256125745], + [-14.96826256125745, -7.0169811855980315], + [-7.0169811855980315, -1.2434497875801753e-14], + [-1.2434497875801753e-14, 7.016981185598007], + [7.016981185598007, 14.968262561257426], + [14.968262561257426, 26.537308698055483], + [26.537308698055483, 48.151132374475324], + [48.151132374475324, 99.99999999999997]] for a, b in zip(ranges, gt): self.assertAlmostEqual(a[0], b[0]) self.assertAlmostEqual(a[1], b[1]) def test_get_loguniform_ranges(self): interval = [1, 1000] N = 10 ranges = get_loguniform_ranges(interval[0], interval[1], N) gt = [[1.0, 1.9952623149688797], [1.9952623149688797, 3.9810717055349727], [3.9810717055349727, 7.943282347242818], [7.943282347242818, 15.848931924611136], [15.848931924611136, 31.62277660168379], [31.62277660168379, 63.095734448019364], [63.095734448019364, 125.89254117941677], [125.89254117941677, 251.18864315095806], [251.18864315095806, 501.1872336272723], [501.1872336272723, 999.9999999999998]] for a, b in zip(ranges, gt): self.assertAlmostEqual(a[0], b[0]) self.assertAlmostEqual(a[1], b[1]) interval = [1, 10000] N = 50 ranges = get_loguniform_ranges(interval[0], interval[1], N) gt = [[1.0, 1.202264434617413], [1.202264434617413, 1.4454397707459274], [1.4454397707459274, 1.7378008287493756], [1.7378008287493756, 2.0892961308540396], [2.0892961308540396, 2.51188643150958], [2.51188643150958, 3.0199517204020165], [3.0199517204020165, 3.6307805477010135], [3.6307805477010135, 4.36515832240166], [4.36515832240166, 5.248074602497727], [5.248074602497727, 6.309573444801933], [6.309573444801933, 7.5857757502918375], [7.5857757502918375, 9.120108393559098], 
[9.120108393559098, 10.964781961431854], [10.964781961431854, 13.18256738556407], [13.18256738556407, 15.848931924611136], [15.848931924611136, 19.054607179632477], [19.054607179632477, 22.908676527677738], [22.908676527677738, 27.542287033381676], [27.542287033381676, 33.11311214825911], [33.11311214825911, 39.810717055349734], [39.810717055349734, 47.863009232263856], [47.863009232263856, 57.543993733715695], [57.543993733715695, 69.18309709189366], [69.18309709189366, 83.17637711026713], [83.17637711026713, 100.00000000000004], [100.00000000000004, 120.22644346174135], [120.22644346174135, 144.54397707459285], [144.54397707459285, 173.78008287493753], [173.78008287493753, 208.92961308540396], [208.92961308540396, 251.18864315095806], [251.18864315095806, 301.9951720402017], [301.9951720402017, 363.0780547701015], [363.0780547701015, 436.5158322401662], [436.5158322401662, 524.8074602497729], [524.8074602497729, 630.9573444801938], [630.9573444801938, 758.5775750291845], [758.5775750291845, 912.0108393559099], [912.0108393559099, 1096.4781961431854], [1096.4781961431854, 1318.2567385564075], [1318.2567385564075, 1584.8931924611143], [1584.8931924611143, 1905.4607179632485], [1905.4607179632485, 2290.867652767775], [2290.867652767775, 2754.228703338169], [2754.228703338169, 3311.3112148259115], [3311.3112148259115, 3981.071705534977], [3981.071705534977, 4786.300923226385], [4786.300923226385, 5754.399373371577], [5754.399373371577, 6918.309709189369], [6918.309709189369, 8317.63771102671], [8317.63771102671, 10000.00000000001]] for a, b in zip(ranges, gt): self.assertAlmostEqual(a[0], b[0]) self.assertAlmostEqual(a[1], b[1]) def test_QuasiRandomSampleGenerator(self): N_samples = 10*10*10 axis_data = {"p1": {"domain": "loguniform", "data": [1, 10000], "type": "float"}, "p2": {"domain": "normal", "data": [-5, 5], "type": "float"}, "p3": {"domain": "uniform", "data": [0, 10], "type": "float"}, "p4": {"domain": "categorical", "data": [False, True], "type": "bool"}} sampler = QuasiRandomSampleGenerator(N_samples, 0.1) for name, axis in axis_data.items(): sampler.set_axis(name, axis["data"], axis["domain"], axis["type"]) for i in range(N_samples): sample = sampler.next() self.assertTrue(len(sample.keys()) == 4) for k in range(4): self.assertTrue("p{}".format(k+1) in sample.keys()) self.assertTrue(1 <= sample["p1"] <= 10000) self.assertTrue(-5 <= sample["p2"] <= 5) self.assertTrue(0 <= sample["p3"] <= 10) self.assertTrue(isinstance(sample["p4"], bool)) self.assertTrue(sampler.next() is None) if __name__ == '__main__': unittest.main()
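
The tutorial comments above describe the preprocess_func and kwargs mechanics of BlackboxFunction without exercising them in the example code. The following sketch (not part of this diff) shows what that could look like; my_preproc_param is an invented custom parameter, and the my_* functions from the tutorial are assumed to be in scope.

# hypothetical extension of the tutorial: my_preproc_param is a made-up custom
# parameter that is forwarded via kwargs to the registered function pointers
from hyppopy.BlackboxFunction import BlackboxFunction

def my_preprocess_function(**kwargs):
    print("Preprocessing with my_preproc_param={}".format(kwargs.get("my_preproc_param")))

blackbox = BlackboxFunction(blackbox_func=my_loss_function,
                            dataloader_func=my_dataloader_function,
                            preprocess_func=my_preprocess_function,
                            callback_func=my_callback_function,
                            my_preproc_param=42)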
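
Since the DEFAULTITERATIONS fallback is removed in this change, solvers that rely on max_iterations now raise an exception when the setting is missing. A minimal sketch (iteration count chosen arbitrarily) of providing it, either through the config dict or programmatically via add_settings:

from hyppopy.HyppopyProject import HyppopyProject

# variant 1: via the config dict (the settings section is now optional, but
# max_iterations must be present for solvers that need it)
config = {
    "hyperparameter": {
        "gamma": {"domain": "uniform", "data": [0.0001, 20.0], "type": "float"}
    },
    "settings": {"solver": {"max_iterations": 300}}
}
project = HyppopyProject(config=config)

# variant 2: programmatically; parse_members exposes each setting as a
# <section>_<name> member, which run() looks up as solver_max_iterations
project2 = HyppopyProject()
project2.add_hyperparameter(name="gamma", domain="uniform", data=[0.0001, 20.0], dtype="float")
project2.add_settings("solver", "max_iterations", 300)
assert project.solver_max_iterations == project2.solver_max_iterations == 300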
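
The HyppopySolver docstring describes the three abstract methods a solver add-on has to implement. The sketch below is not a solver shipped with hyppopy; it is a deliberately naive random-sampling add-on (uniform and categorical domains only) meant to show how convert_searchspace, execute_solver and loss_function_call fit together. It assumes the wrapped blackbox can be called with the sampled parameters as keyword arguments.

import random
from hyppopy.solvers.HyppopySolver import HyppopySolver


class NaiveRandomSolver(HyppopySolver):
    """Illustrative add-on only: samples at random for max_iterations iterations."""

    def convert_searchspace(self, hyperparameter):
        # a real add-on would translate the hyppopy description into its
        # backend's format here; this sketch keeps it as-is
        return hyperparameter

    def execute_solver(self, searchspace):
        best_loss = None
        for _ in range(self.max_iterations):
            params = {}
            for name, axis in searchspace.items():
                if axis["domain"] == "categorical":
                    params[name] = random.choice(axis["data"])
                else:  # treated as uniform over [low, high] in this sketch
                    params[name] = random.uniform(axis["data"][0], axis["data"][1])
            # loss_function (base class) handles trial bookkeeping and callbacks
            loss = self.loss_function(**params)
            if best_loss is None or loss < best_loss:
                best_loss = loss
                self.best = params

    def loss_function_call(self, params):
        # assumption: the blackbox accepts the sampled parameters as kwargs
        return self.blackbox(**params)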
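
Finally, a short usage sketch for the history returned by get_results(): the DataFrame has one row per iteration with a duration column in milliseconds, the loss, a boolean status flag and one column per hyperparameter (solver here stands for any HyppopySolver instance after run()).

df, best = solver.get_results()
top5 = df.sort_values(by="losses").head(5)   # the five best iterations
n_failed = int((~df["status"]).sum())        # iterations whose loss computation failed
print(top5)
print("failed iterations: {}".format(n_failed))
print("best parameter set: {}".format(best))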