"""
..
Copyright (c) 2018 LG Electronics Inc.
SPDX-License-Identifier: GPL-3.0-or-later
aup.Proposer.AbstractProposer
=============================
:mod:`aup.Proposer.AbstractProposer` provides the interface for hyperparameter optimization modules.
APIs
----
"""
import abc
import logging
import pickle
import json
import threading
from six.moves import input
from ..utils import set_default_keyvalue, check_missing_key, get_from_options
from . import ProposerStatus
# Abstract base class built directly from ABCMeta so the same definition works on
# both Python 2 and Python 3 (no per-instance __dict__ thanks to the empty __slots__).
ABC = abc.ABCMeta('ABC', (object,), dict(__slots__=()))
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
def create_param_config(name, vrange, vtype):
    """
    Build a single hyperparameter configuration entry and validate it.

    :param name: hyperparameter name
    :type name: str
    :param vrange: candidate values for ``'choice'``; a two-element range
        (presumably ``[low, high]``) for ``'float'`` / ``'int'``
    :type vrange: list or tuple
    :param vtype: one of ``'float'``, ``'int'``, ``'choice'``
    :type vtype: str
    :return: ``{'name': name, 'range': vrange, 'type': vtype}``
    :rtype: dict
    :raises ValueError: if ``vtype`` is unsupported, a non-``choice`` range does not
        have exactly two elements, or a ``choice`` range is empty
    """
    if vtype not in ('float', 'int', 'choice'):
        raise ValueError("%s is not supported as hyperparameter type" % vtype)
    if vtype != 'choice' and len(vrange) != 2:
        raise ValueError("Range need to be a two element tuple for %s" % vtype)
    # A choice with no candidates can never be sampled; reject it here, matching the
    # "at least one element" rule enforced by the interactive setup.
    if vtype == 'choice' and len(vrange) == 0:
        raise ValueError("Range needs at least one element for %s" % vtype)
    return {'name': name, 'range': vrange, 'type': vtype}
class AbstractProposer(ABC):
    """
    Proposer to generate new values for hyperparameters.

    Subclasses implement :func:`get_param`. This base class tracks the number of
    executed jobs (``counter``) against the job budget (``nSamples``) and guards
    that bookkeeping and ``status`` with ``status_lock``.

    :param config: experiment configuration
    :type config: BasicConfig
    """

    def __init__(self, config):
        self.nSamples = 0  # number of total jobs for an experiment
        self.counter = 0  # number of executed jobs
        self.current_proposal = {}
        self.status = ProposerStatus.RUNNING  # whether the experiment is finished
        # Guards status/counter. A Lock is not picklable, so it is dropped in
        # __getstate__ and recreated in __setstate__.
        self.status_lock = threading.Lock()
        # NOTE(review): calls the base-class check explicitly, bypassing any subclass
        # override of verify_config; presumably subclasses run their own checks — confirm.
        AbstractProposer.verify_config(self, config)

    def set_status(self, status):
        """Set ``status`` under ``status_lock``."""
        with self.status_lock:
            self.status = status

    def get_status(self):
        """Return ``status`` read under ``status_lock``."""
        with self.status_lock:
            return self.status

    def increment_job_counter(self):
        """Increase the executed-job ``counter`` by one under ``status_lock``."""
        with self.status_lock:
            self.counter += 1

    def check_termination(self):
        """Mark the proposer FINISHED once ``counter`` has reached ``nSamples``."""
        with self.status_lock:
            if self.counter >= self.nSamples:
                self.status = ProposerStatus.FINISHED

    def get_remaining_jobs(self):
        """Return ``nSamples - counter``, read under ``status_lock``."""
        with self.status_lock:
            return self.nSamples - self.counter

    @staticmethod
    def setup_config():  # pragma: no cover
        """
        Interactively build ``parameter_config`` from user input.

        Repeatedly prompts for a variable name, its range and its type until the user
        enters ``stop``, presses Ctrl+C, or input ends (EOF). Variables completed before
        an interrupt are kept.

        :return: ``{"parameter_config": [...]}``
        :rtype: dict
        :raises ValueError: if a parsed range is empty
        """
        config = []
        try:
            print("start adding hyperparameters, use 'stop' for variable name or ctrl+c to exit")
            while True:
                name = input("variable name:")
                if name == "stop":
                    break
                try:
                    res = input("range (separated by ,):")
                    # Accept Python-style single quotes by converting them to JSON double quotes.
                    if "'" in res:
                        res = res.replace("'", '"')
                    # startswith/endswith (unlike res[0]/res[-1]) do not raise IndexError on
                    # an empty answer; an empty answer parses to [] and is rejected below.
                    if res.startswith('[') and res.endswith(']'):
                        vrange = json.loads(res)
                    else:
                        vrange = json.loads("[" + res + "]")
                except ValueError:
                    logger.critical("failed to parse range, treat it as strings separated by ','")
                    vrange = res.split(",")
                if not vrange:
                    raise ValueError("range needs at least one element")
                vtype = get_from_options("type:", ("choice", "float", "int"))
                config.append({'name': name, "range": vrange, "type": vtype})
        except (KeyboardInterrupt, EOFError):
            # EOF (e.g. piped stdin running out) is treated like Ctrl+C so that the
            # variables entered so far are not lost.
            print("Config interrupted, completed variables are saved.")
        return {"parameter_config": config}

    @staticmethod
    def parse_param_config(config):
        """
        Parse one hyperparameter entry of ``parameter_config``, filling in defaults.

        A missing ``name`` is reported through ``check_missing_key``. ``type`` defaults
        to ``"int"`` and ``range`` defaults to ``[0, 1]``.

        :param config: one entry of ``config["parameter_config"]``
        :type config: dict
        :return: updated config
        :rtype: dict
        """
        check_missing_key(config, "name",
                          "Missing name of the parameter, need to be consistent with your training code",
                          log=logger)
        set_default_keyvalue("type", "int", config, log=logger)
        set_default_keyvalue("range", [0, 1], config, log=logger)
        return config

    def get(self, **kwargs):
        """
        Wrapper for the subclass-specific :func:`get_param` that records ``current_proposal``.

        Termination is checked first. ``counter`` is not advanced here; that happens
        through :func:`increment_job_counter`.

        :param kwargs: any arguments to be passed to :func:`get_param`
        :type kwargs: dict
        :return: a shallow copy of the proposed parameter values, or ``None`` when the
            proposer is no longer running or :func:`get_param` returned nothing
        :rtype: dict or None
        """
        self.check_termination()
        if self.get_status() != ProposerStatus.RUNNING:
            return None
        self.current_proposal = self.get_param(**kwargs)
        logger.debug(self.current_proposal)
        if not self.current_proposal:
            return None
        # Copy so callers cannot mutate the proposer's own record of the proposal.
        return self.current_proposal.copy()

    @abc.abstractmethod
    def get_param(self, **kwargs):
        """
        Get new proposed parameter values (implemented by each concrete proposer).
        """
        raise NotImplementedError

    def reload(self, path):
        """
        Reload Proposer state from a file written by :func:`save`.

        :param path: path to reload
        :type path: str
        :return: this proposer, with its attributes replaced by the saved ones
        """
        logger.info("Reload %s, previous cancelled job won't be run", path)
        # SECURITY: pickle.load can execute arbitrary code. Only reload files written by
        # save(); never pass an untrusted path here.
        with open(path, 'rb') as f:
            d = pickle.load(f)
        # Unpickling ran __setstate__, so the copied state carries a fresh status_lock.
        self.__dict__.update(vars(d))
        return self

    def reset(self):
        """
        Reset the executed-job counter to zero and set the status back to RUNNING.
        """
        with self.status_lock:
            self.counter = 0
            self.status = ProposerStatus.RUNNING

    def save(self, path):
        """
        Save Proposer state to path (pickled; ``status_lock`` is excluded).

        **Some proposer can not generate new parameters after saving.**

        :param path: path to save
        :type path: str
        """
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    def update(self, score, job):
        """
        Update scores in proposer history (base implementation only logs).

        :param score: score returned by Job
        :type score: float
        :param job: Finished job
        :type job: Job
        """
        logger.debug("Get score (%s) for job %s", score, job.jid)

    def failed(self, job):
        """
        Mark job as failed in proposer history (base implementation only logs).

        :param job: Failed job
        :type job: Job
        """
        logger.debug("Job %s marked as failed", job.jid)

    def verify_config(self, config):
        """
        Verify the input configuration is enough for the proposer.

        Requires a ``parameter_config`` key, and a ``name`` in every one of its entries.

        :param config: experiment configuration containing ``parameter_config``
        :type config: dict
        :return: config
        :rtype: dict
        """
        check_missing_key(config, "parameter_config",
                          "Specify the parameter configuration `parameter_config` to be searched", log=logger)
        for i in config["parameter_config"]:
            check_missing_key(i, "name", "hyperparameter name is missing", log=logger)
        return config

    def __getstate__(self):
        # threading.Lock cannot be pickled; drop it (tolerating its absence).
        state = self.__dict__.copy()
        state.pop('status_lock', None)
        return state

    def __setstate__(self, state):
        # Recreate the lock that __getstate__ removed.
        self.__dict__.update(state)
        self.status_lock = threading.Lock()