# Source code for aup.Proposer.RandomProposer

"""
..
  Copyright (c) 2018 LG Electronics Inc.
  SPDX-License-Identifier: GPL-3.0-or-later

aup.Proposer.RandomProposer
===========================

Random sampling of the parameters

Configuration
-------------

General parameters
~~~~~~~~~~~~~~~~~~

============ ============= ========================================
Name         Default value Explanation
============ ============= ========================================
proposer     -             random
n_samples    -             Total number of trials to sample
random_seed  0             [Optional] seed for random generator
============ ============= ========================================

Specific parameters for ``parameter_config``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

======= ==========================================================================
Name    Explanation
======= ==========================================================================
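name    name of the variable; it will be used in the job config, i.e. by the
        training code
type    type of the parameter to be sampled: one of "float", "int", "choice"
range   range of the parameter. For "int", ``[low, high]`` with both ends
        inclusive; for "float", ``[low, high]``, sampled uniformly from
        [low, high); for "choice", a list of all feasible values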
======= ==========================================================================

APIs
----
"""
import logging

from numpy import random
from six.moves import input

from .AbstractProposer import AbstractProposer
from ..utils import check_missing_key, set_default_keyvalue

logger = logging.getLogger(__name__)


def _random_int(x):
    if len(x) != 2:
        msg = ("Range of random integer should have two elements, got %d" % len(x))
        logger.fatal(msg)
        raise ValueError(msg)
    return lambda: random.randint(x[0], x[1] + 1)


def _random_float(x):
    if len(x) != 2:
        msg = ("Range of random float should have two elements, got %d" % len(x))
        logger.fatal(msg)
        raise ValueError(msg)
    return lambda: random.rand() * (x[1] - x[0]) + x[0]


def _random_choice(x):
    if len(x) < 1:
        msg = "Range of random choice should have some elements, got nothing"
        logger.fatal(msg)
        raise ValueError(msg)
    return lambda: x[random.choice(len(x))]


# Maps a parameter ``type`` from ``parameter_config`` to the factory that
# builds its zero-argument sampler from the parameter's ``range``.
_random_fun = dict(
    int=_random_int,
    float=_random_float,
    choice=_random_choice,
)


class RandomProposer(AbstractProposer):
    """Random proposer that samples every parameter independently.

    One sampler is built per entry of ``parameter_config``; all samplers draw
    from numpy's module-level random generator, which is seeded here.

    :param config: experiment configuration describing the search space
    :param random_seed: default random seed if ``config`` has no ``random_seed``
    """

    def __init__(self, config, random_seed=0):
        super(RandomProposer, self).__init__(config)
        self.verify_config(config)
        self.nSamples = config["n_samples"]
        set_default_keyvalue("random_seed", random_seed, config, log=logger)
        random.seed(config["random_seed"])
        # numpy RNG snapshot: set by save(), consumed and cleared by reload().
        self.random_state = None
        self.params_gen = {}
        for spec in config["parameter_config"]:
            parsed = self.parse_param_config(spec)
            make_sampler = _random_fun[parsed['type']]
            self.params_gen[parsed['name']] = make_sampler(parsed['range'])

    @staticmethod
    def setup_config():  # pragma: no cover
        """Interactively build an experiment configuration.

        :return: experiment config in dict.
        """
        n_samples = int(input("number of model samples to draw randomly, `n_samples`, [1]:") or 1)
        seed = int(input("random seed, `random_seed`, [0]:") or 0)
        config = {'n_samples': n_samples, 'random_seed': seed}
        config.update(AbstractProposer.setup_config())
        return config

    def get_param(self, **kwargs):
        """Draw the next parameter set.

        :return: ``self.current_proposal``, updated in place with one freshly
            sampled value per parameter name; ``None`` if the samplers are
            absent (``save()`` removes them)
        """
        if 'params_gen' in self.__dict__:
            for name, sampler in self.params_gen.items():
                self.current_proposal[name] = sampler()
            logger.debug(self.current_proposal)
            return self.current_proposal
        return None

    def reload(self, path):
        """Reload a saved proposer and restore numpy's RNG state.

        :param path: location passed through to the base-class ``reload``
        :return: this proposer
        """
        super(RandomProposer, self).reload(path)
        random.set_state(self.random_state)
        self.random_state = None  # the snapshot is only needed in the saved copy
        return self

    def save(self, path):
        """Save the proposer together with numpy's current RNG state.

        The samplers (closures) are removed first -- presumably because the
        base-class save serializes this object and closures cannot be pickled;
        confirm against ``AbstractProposer.save``.

        :param path: location passed through to the base-class ``save``
        """
        if 'params_gen' in self.__dict__:
            del self.params_gen
        self.random_state = random.get_state()
        super(RandomProposer, self).save(path)

    def verify_config(self, config):
        """Check that ``config`` specifies ``n_samples``.

        :param config: experiment configuration
        :return: the same ``config`` object
        """
        check_missing_key(config, "n_samples", "Specify number of samples to randomly draw", log=logger)
        return config