# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import copy
import math
import random
import numpy as np

from .tunable_variable import Boolean
from .tunable_variable import Fixed
from .tunable_variable import Choice
from .tunable_variable import IntRange
from .tunable_variable import FloatRange


class TunableSpace(object):
    """
    A TunableSpace is constructed by the tunable variables.

    It keeps two mappings keyed by variable name:

    * ``variables``: name -> tunable variable object (Boolean, Fixed, ...).
    * ``values``: name -> the value currently assigned to that variable.

    Variables are registered through the factory methods (``fixed``,
    ``boolean``, ``choice``, ``int_range``, ``float_range``), each of which
    returns the variable's current value.
    """

    def __init__(self):
        # Tunable variables registered in this space, keyed by name.
        self._variables = {}
        # Current value corresponding to each tunable variable, keyed by name.
        self._values = {}

    @property
    def variables(self):
        """Dict mapping variable name to its tunable variable object."""
        return self._variables

    @property
    def values(self):
        """Dict mapping variable name to its currently assigned value."""
        return self._values

    def get_value(self, name):
        """Return the current value of the variable called ``name``.

        Raises:
            KeyError: if no value is recorded for ``name``.
        """
        if name in self.values:
            return self.values[name]
        else:
            raise KeyError("{} does not exist.".format(name))

    def set_value(self, name, value):
        """Overwrite the value of an already-known variable ``name``.

        Only names that already have a value may be set; new names must be
        introduced through one of the factory methods.

        Raises:
            KeyError: if no value is recorded for ``name``.
        """
        if name in self.values:
            self.values[name] = value
        else:
            raise KeyError("{} does not exist.".format(name))

    def _exists(self, name):
        """Return True if a variable called ``name`` has been registered."""
        return name in self._variables

    def _retrieve(self, tv):
        """Register ``tv`` if it is new, and return the variable's value.

        If a variable with the same name is already registered, the existing
        definition and value are kept and ``tv`` is discarded.
        """
        # Rebuild the variable from its serialized state so the space stores
        # its own object (presumably an independent copy of the caller's).
        tv = tv.__class__.from_state(tv.get_state())
        if self._exists(tv.name):
            return self.get_value(tv.name)
        return self._register(tv)

    def _register(self, tv):
        """Store ``tv`` and return its value, seeding it with its default."""
        self._variables[tv.name] = tv
        # Keep a value that was already assigned under this name; otherwise
        # start from the variable's default.
        if tv.name not in self.values:
            self.values[tv.name] = tv.default
        return self.values[tv.name]

    def __getitem__(self, name):
        return self.get_value(name)

    def __setitem__(self, name, value):
        self.set_value(name, value)

    def __contains__(self, name):
        try:
            self.get_value(name)
            return True
        except (KeyError, ValueError):
            return False

    def fixed(self, name, default):
        """Register a Fixed variable ``name`` and return its current value."""
        tv = Fixed(name=name, default=default)
        return self._retrieve(tv)

    def boolean(self, name, default=False):
        """Register a Boolean variable ``name`` and return its current value."""
        tv = Boolean(name=name, default=default)
        return self._retrieve(tv)

    def choice(self, name, values, default=None):
        """Register a Choice variable over ``values``; return its value."""
        tv = Choice(name=name, values=values, default=default)
        return self._retrieve(tv)

    def int_range(self, name, start, stop, step=1, default=None):
        """Register an IntRange variable ``name``; return its current value."""
        tv = IntRange(
            name=name, start=start, stop=stop, step=step, default=default)
        return self._retrieve(tv)

    def float_range(self, name, start, stop, step=None, default=None):
        """Register a FloatRange variable ``name``; return its current value."""
        tv = FloatRange(
            name=name, start=start, stop=stop, step=step, default=default)
        return self._retrieve(tv)

    def get_state(self):
        """Serialize the space into plain Python containers.

        Returns:
            A dict with ``"variables"`` (a list of ``{"class_name", "state"}``
            dicts, one per registered variable) and ``"values"`` (a shallow
            copy of the name -> value mapping).
        """
        return {
            "variables": [{
                "class_name": v.__class__.__name__,
                "state": v.get_state()
            } for v in self._variables.values()],
            "values": {k: v for k, v in self.values.items()}
        }

    @classmethod
    def from_state(cls, state):
        """Rebuild a TunableSpace from the output of ``get_state``."""
        ts = cls()
        for v in state["variables"]:
            v = _deserialize_tunable_variable(v)
            ts._variables[v.name] = v
        ts._values = {k: v for k, v in state["values"].items()}
        return ts


def _deserialize_tunable_variable(state):
    """Rebuild a tunable variable from its serialized form.

    Args:
        state: either an already-constructed tunable variable instance, which
            is returned unchanged, or a dict with ``"class_name"`` and
            ``"state"`` keys as produced by ``TunableSpace.get_state``.

    Returns:
        The tunable variable instance.

    Raises:
        ValueError: if ``state`` is not a dict containing both keys, or if
            ``class_name`` does not name a supported tunable variable class.
    """
    classes = (Boolean, Fixed, Choice, IntRange, FloatRange)
    cls_name_to_cls = {cls.__name__: cls for cls in classes}

    if isinstance(state, classes):
        return state

    if (not isinstance(state, dict) or "class_name" not in state or
            "state" not in state):
        raise ValueError(
            "Expect state to be a python dict containing class_name and state as keys, but found {}"
            .format(state))

    cls_name = state["class_name"]
    # Use .get() so an unknown name reaches the ValueError below; plain
    # indexing would raise a bare KeyError and make this check unreachable.
    cls = cls_name_to_cls.get(cls_name)
    if cls is None:
        raise ValueError("Unknown class name {}".format(cls_name))

    cls_state = state["state"]
    deserialized_object = cls.from_state(cls_state)
    return deserialized_object