tunable_space.py 4.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
# Notice that the following code is modified from KerasTuner to implement our own tuner.
16 17
# Please refer to https://github.com/keras-team/keras-tuner/blob/master/keras_tuner/engine/hyperparameters.py.

18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
from .tunable_variable import Boolean
from .tunable_variable import Fixed
from .tunable_variable import Choice
from .tunable_variable import IntRange
from .tunable_variable import FloatRange


class TunableSpace(object):
    """
    A TunableSpace is constructed by the tunable variables.

    It keeps two mappings keyed by variable name: the registered tunable
    variable objects (``variables``) and the concrete value currently chosen
    for each of them (``values``).
    """

    def __init__(self):
        # Tunable variables registered in this space, keyed by name.
        self._variables = {}
        # Specific values corresponding to each tunable variable, keyed by name.
        self._values = {}

    @property
    def variables(self):
        """Mapping from name to the registered tunable variable object."""
        return self._variables

    @property
    def values(self):
        """Mapping from name to the value currently chosen for that variable."""
        return self._values

    def get_value(self, name):
        """Return the current value recorded for ``name``.

        Raises:
            KeyError: if no value has been recorded for ``name``.
        """
        if name not in self.values:
            raise KeyError("{} does not exist.".format(name))
        return self.values[name]

    def set_value(self, name, value):
        """Overwrite the value recorded for an existing ``name``.

        This never adds a new name; use one of the variable-defining methods
        (``fixed``, ``choice``, ...) to register a variable first.

        Raises:
            KeyError: if no value has been recorded for ``name``.
        """
        if name not in self.values:
            raise KeyError("{} does not exist.".format(name))
        self.values[name] = value

    def _exists(self, name):
        """Return True if a tunable variable called ``name`` is registered."""
        return name in self._variables

    def _retrieve(self, tv):
        """Return the value for ``tv``, registering it first if it is new.

        ``tv`` is rebuilt through its ``get_state``/``from_state`` round-trip
        (presumably so the space does not share the caller's object). If a
        variable with the same name is already registered, the new definition
        is discarded and the existing recorded value is returned.
        """
        tv = tv.__class__.from_state(tv.get_state())
        if self._exists(tv.name):
            return self.get_value(tv.name)
        return self._register(tv)

    def _register(self, tv):
        """Store ``tv`` under its name and return its recorded value.

        A value already recorded under ``tv.name`` (for example one restored
        by ``from_state`` or assigned earlier) takes precedence over
        ``tv.default``.
        """
        self._variables[tv.name] = tv
        # Only fall back to the default when nothing has been recorded yet.
        self.values.setdefault(tv.name, tv.default)
        return self.values[tv.name]

    def __getitem__(self, name):
        return self.get_value(name)

    def __setitem__(self, name, value):
        self.set_value(name, value)

    def __contains__(self, name):
        # Membership mirrors get_value: a name is "in" the space when a value
        # has been recorded for it.
        return name in self.values

    def fixed(self, name, default):
        """Define (or look up) a ``Fixed`` variable and return its value."""
        tv = Fixed(name=name, default=default)
        return self._retrieve(tv)

    def boolean(self, name, default=False):
        """Define (or look up) a ``Boolean`` variable and return its value."""
        tv = Boolean(name=name, default=default)
        return self._retrieve(tv)

    def choice(self, name, values, default=None):
        """Define (or look up) a ``Choice`` variable and return its value."""
        tv = Choice(name=name, values=values, default=default)
        return self._retrieve(tv)

    def int_range(self, name, start, stop, step=1, default=None):
        """Define (or look up) an ``IntRange`` variable and return its value."""
        tv = IntRange(name=name,
                      start=start,
                      stop=stop,
                      step=step,
                      default=default)
        return self._retrieve(tv)

    def float_range(self, name, start, stop, step=None, default=None):
        """Define (or look up) a ``FloatRange`` variable and return its value."""
        tv = FloatRange(name=name,
                        start=start,
                        stop=stop,
                        step=step,
                        default=default)
        return self._retrieve(tv)

    def get_state(self):
        """Return a plain-dict snapshot of this space.

        The snapshot has two keys: ``"variables"``, a list with one
        ``{"class_name": ..., "state": ...}`` entry per registered variable,
        and ``"values"``, a shallow copy of the recorded values.
        """
        return {
            "variables": [{
                "class_name": tv.__class__.__name__,
                "state": tv.get_state()
            } for tv in self._variables.values()],
            "values": dict(self.values),
        }

    @classmethod
    def from_state(cls, state):
        """Rebuild a space from a snapshot produced by ``get_state``."""
        ts = cls()
        for entry in state["variables"]:
            tv = _deserialize_tunable_variable(entry)
            ts._variables[tv.name] = tv
        # Copy so later edits to the space do not mutate the caller's dict.
        ts._values = dict(state["values"])
        return ts


def _deserialize_tunable_variable(state):
    """Convert a serialized tunable variable back into an object.

    Args:
        state: either an already-constructed tunable variable instance
            (returned unchanged) or a dict with ``"class_name"`` and
            ``"state"`` keys, such as the entries of
            ``TunableSpace.get_state()["variables"]``.

    Returns:
        The object built by ``<class>.from_state(state["state"])``.

    Raises:
        ValueError: if ``state`` is not a dict with the expected keys, or if
            ``"class_name"`` does not name a known tunable variable class.
    """
    classes = (Boolean, Fixed, Choice, IntRange, FloatRange)
    cls_name_to_cls = {cls.__name__: cls for cls in classes}

    if isinstance(state, classes):
        return state

    if (not isinstance(state, dict) or "class_name" not in state
            or "state" not in state):
        raise ValueError(
            "Expect state to be a python dict containing class_name and state as keys, but found {}"
            .format(state))

    cls_name = state["class_name"]
    # Use .get() so an unknown name reaches the ValueError below; indexing the
    # dict directly would raise KeyError and make this check unreachable.
    cls = cls_name_to_cls.get(cls_name)
    if cls is None:
        raise ValueError("Unknown class name {}".format(cls_name))

    return cls.from_state(state["state"])