未验证 提交 f84b54eb 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto parallel] Redesign the tuner for auto parallel (#40121)

* [Auto Parallel] Redesign the tuner for Auto Parallel
上级 0c333543
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
class Storable(object):
    """
    Mixin that gives subclasses JSON-file persistence of their state.

    Subclasses must implement ``get_state`` (returning a JSON-serializable
    object) and ``set_state`` (restoring from such an object).
    """

    def get_state(self):
        """Return a JSON-serializable representation of this object."""
        raise NotImplementedError

    def set_state(self, state):
        """Restore this object from a state produced by ``get_state``."""
        raise NotImplementedError

    def save(self, path):
        """Serialize the current state to ``path``; return the path as str."""
        with open(path, "w") as f:
            f.write(json.dumps(self.get_state()))
        return str(path)

    def load(self, path):
        """Read ``path`` and restore the state stored in it."""
        with open(path, "r") as f:
            self.set_state(json.loads(f.read()))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import copy
import math
import random
import numpy as np
from .tunable_variable import Boolean
from .tunable_variable import Fixed
from .tunable_variable import Choice
from .tunable_variable import IntRange
from .tunable_variable import FloatRange
class TunableSpace(object):
    """
    A TunableSpace holds a set of tunable variables and the concrete value
    currently chosen for each of them.
    """

    def __init__(self):
        # Registered tunable variables, keyed by variable name.
        self._variables = {}
        # The value currently selected for each registered variable.
        self._values = {}

    @property
    def variables(self):
        return self._variables

    @property
    def values(self):
        return self._values

    def get_value(self, name):
        """Return the current value of the named variable."""
        if name not in self._values:
            raise KeyError("{} does not exist.".format(name))
        return self._values[name]

    def set_value(self, name, value):
        """Overwrite the value of an already-registered variable."""
        if name not in self._values:
            raise KeyError("{} does not exist.".format(name))
        self._values[name] = value

    def _exists(self, name):
        # True when a variable with this name has been registered.
        return name in self._variables

    def _retrieve(self, tv):
        # Work on a clone so the caller's object is never stored directly.
        clone = tv.__class__.from_state(tv.get_state())
        if self._exists(clone.name):
            return self.get_value(clone.name)
        return self._register(clone)

    def _register(self, tv):
        # Record the variable and seed its value with the default unless a
        # value was already populated (e.g. by a tuner) beforehand.
        self._variables[tv.name] = tv
        self._values.setdefault(tv.name, tv.default)
        return self._values[tv.name]

    def __getitem__(self, name):
        return self.get_value(name)

    def __setitem__(self, name, value):
        self.set_value(name, value)

    def __contains__(self, name):
        try:
            self.get_value(name)
        except (KeyError, ValueError):
            return False
        return True

    def fixed(self, name, default):
        """Declare a fixed (non-tunable) variable and return its value."""
        return self._retrieve(Fixed(name=name, default=default))

    def boolean(self, name, default=False):
        """Declare a boolean variable and return its value."""
        return self._retrieve(Boolean(name=name, default=default))

    def choice(self, name, values, default=None):
        """Declare a choice variable and return its value."""
        return self._retrieve(Choice(name=name, values=values, default=default))

    def int_range(self, name, start, stop, step=1, default=None):
        """Declare an integer-range variable and return its value."""
        return self._retrieve(
            IntRange(
                name=name, start=start, stop=stop, step=step, default=default))

    def float_range(self, name, start, stop, step=None, default=None):
        """Declare a float-range variable and return its value."""
        return self._retrieve(
            FloatRange(
                name=name, start=start, stop=stop, step=step, default=default))

    def get_state(self):
        """Serialize the space (variables and values) to plain dicts."""
        variable_states = [{
            "class_name": type(v).__name__,
            "state": v.get_state()
        } for v in self._variables.values()]
        return {"variables": variable_states, "values": dict(self._values)}

    @classmethod
    def from_state(cls, state):
        """Rebuild a TunableSpace from a state produced by ``get_state``."""
        space = cls()
        for variable_state in state["variables"]:
            variable = _deserialize_tunable_variable(variable_state)
            space._variables[variable.name] = variable
        space._values = dict(state["values"])
        return space
def _deserialize_tunable_variable(state):
    """Rebuild a tunable variable from its serialized state.

    Accepts either an already-constructed tunable variable (returned as-is)
    or a dict of the form ``{"class_name": ..., "state": ...}`` as produced
    by ``TunableSpace.get_state``.

    Raises:
        ValueError: if ``state`` is neither a variable nor a well-formed
            state dict, or names an unknown variable class.
    """
    classes = (Boolean, Fixed, Choice, IntRange, FloatRange)
    cls_name_to_cls = {cls.__name__: cls for cls in classes}

    if isinstance(state, classes):
        # Already a live tunable variable; nothing to deserialize.
        return state

    if (not isinstance(state, dict) or "class_name" not in state or
            "state" not in state):
        raise ValueError(
            "Expect state to be a python dict containing class_name and state as keys, but found {}"
            .format(state))

    cls_name = state["class_name"]
    # Use .get so an unknown class name raises the intended ValueError below
    # instead of an opaque KeyError from the dict subscript.
    cls = cls_name_to_cls.get(cls_name)
    if cls is None:
        raise ValueError("Unknown class name {}".format(cls_name))

    return cls.from_state(state["state"])
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
class TunableVariable(object):
    """
    Abstract base for every tunable variable.

    A tunable variable has a name and a default value, and can be serialized
    to / rebuilt from a plain state dict.
    """

    def __init__(self, name, default=None):
        self.name = name
        self._default = default

    @property
    def default(self):
        return self._default

    def get_state(self):
        # NOTE: uses the `default` property (not `_default`) on purpose, so
        # subclasses that resolve a fallback default serialize that value.
        state = {"name": self.name}
        state["default"] = self.default
        return state

    @classmethod
    def from_state(cls, state):
        # State keys mirror the constructor arguments.
        return cls(**state)
class Fixed(TunableVariable):
    """
    A variable whose value is pinned and never tuned.

    The default must be a plain str, int, float or bool.
    """

    def __init__(self, name, default):
        # Validate before initializing so an invalid value never reaches the
        # object state; the base constructor sets both name and _default, so
        # the previous re-assignments of them were redundant.
        if not isinstance(default, (str, int, float, bool)):
            raise ValueError(
                "Fixed must be an str, int, float or bool, but found {}"
                .format(default))
        super(Fixed, self).__init__(name=name, default=default)

    def random(self, seed=None):
        """Always return the fixed default, regardless of the seed."""
        return self._default

    def __repr__(self):
        return "Fixed(name: {}, value: {})".format(self.name, self.default)
class Boolean(TunableVariable):
    """
    A variable restricted to the values True and False.
    """

    def __init__(self, name, default=False):
        super(Boolean, self).__init__(name=name, default=default)
        # Only genuine booleans are acceptable defaults.
        if default not in {True, False}:
            raise ValueError(
                "default must be a Python boolean, but got {}".format(default))

    def random(self, seed=None):
        """Sample True or False uniformly; ``seed`` makes it reproducible."""
        return np.random.default_rng(seed).choice((True, False))

    def __repr__(self):
        return 'Boolean(name: "{}", default: {})'.format(self.name,
                                                         self.default)
class Choice(TunableVariable):
    """
    A variable whose value is drawn from a fixed, homogeneous list of choices.

    All choices must share one type among str, int, float and bool; the
    default, when given, must be one of the choices.
    """

    def __init__(self, name, values, default=None):
        super(Choice, self).__init__(name=name, default=default)

        types = set(type(v) for v in values)
        if len(types) > 1:
            raise TypeError(
                "Choice can contain only one type of value, but found values: {} with types: {}."
                .format(str(values), str(types)))

        # Bug fix: test bool BEFORE int. bool is a subclass of int, so the
        # previous int-first ordering silently coerced boolean choices (and
        # the default) to 0/1 and made the bool branch unreachable.
        if isinstance(values[0], bool):
            values = [bool(v) for v in values]
            if default is not None:
                default = bool(default)
        elif isinstance(values[0], str):
            values = [str(v) for v in values]
            if default is not None:
                default = str(default)
        elif isinstance(values[0], int):
            values = [int(v) for v in values]
            if default is not None:
                default = int(default)
        elif isinstance(values[0], float):
            values = [float(v) for v in values]
            if default is not None:
                default = float(default)
        else:
            raise TypeError(
                "Choice can only contain str, int, float, or bool, but found: {} "
                .format(str(values)))
        self.values = values

        if default is not None and default not in values:
            raise ValueError(
                "The default value should be one of the choices {}, but found {}".
                format(values, default))
        self._default = default

    @property
    def default(self):
        # Fall back to the first choice when no explicit default was given.
        if self._default is None:
            if None in self.values:
                return None
            return self.values[0]
        return self._default

    def random(self, seed=None):
        """Sample one of the choices uniformly at random."""
        rng = np.random.default_rng(seed)
        return rng.choice(self.values)

    def get_state(self):
        state = super(Choice, self).get_state()
        state["values"] = self.values
        return state

    def __repr__(self):
        return 'Choice(name: "{}", values: {}, default: {})'.format(
            self.name, self.values, self.default)
class IntRange(TunableVariable):
    """
    An integer variable sampled from [start, stop) on a step grid, or
    [start, stop] when ``endpoint`` is True.
    """

    def __init__(self, name, start, stop, step=1, default=None, endpoint=False):
        super(IntRange, self).__init__(name=name, default=default)
        self.start = self._check_int(start)
        self.stop = self._check_int(stop)
        self.step = self._check_int(step)
        self._default = default
        self.endpoint = endpoint

    @property
    def default(self):
        # Fall back to the range start when no explicit default was given.
        if self._default is not None:
            return self._default
        return self.start

    def random(self, seed=None):
        """Sample an int from the range, snapped to the nearest step point."""
        rng = np.random.default_rng(seed)
        value = (self.stop - self.start) * rng.random() + self.start
        if self.step is not None:
            if self.endpoint:
                # +1e-7 nudges np.arange into including `stop` itself.
                values = np.arange(self.start, self.stop + 1e-7, step=self.step)
            else:
                values = np.arange(self.start, self.stop, step=self.step)
            closest_index = np.abs(values - value).argmin()
            value = values[closest_index]
        return int(value)

    def get_state(self):
        state = super(IntRange, self).get_state()
        state["start"] = self.start
        state["stop"] = self.stop
        state["step"] = self.step
        # Store the raw default (may be None) rather than the resolved one.
        state["default"] = self._default
        # Bug fix: persist `endpoint` so from_state(get_state()) round-trips
        # faithfully instead of silently resetting endpoint to False.
        state["endpoint"] = self.endpoint
        return state

    def _check_int(self, val):
        """Return ``val`` as int; raise ValueError if it is not integral."""
        int_val = int(val)
        if int_val != val:
            raise ValueError("Expects val is an int, but found: {}.".format(
                str(val)))
        return int_val

    def __repr__(self):
        return "IntRange(name: {}, start: {}, stop: {}, step: {}, default: {})".format(
            self.name, self.start, self.stop, self.step, self.default)
class FloatRange(TunableVariable):
    """
    A float variable sampled from [start, stop) — or [start, stop] when
    ``endpoint`` is True — optionally snapped to a step grid.
    """

    def __init__(self,
                 name,
                 start,
                 stop,
                 step=None,
                 default=None,
                 endpoint=False):
        super(FloatRange, self).__init__(name=name, default=default)
        self.stop = float(stop)
        self.start = float(start)
        if step is not None:
            self.step = float(step)
        else:
            self.step = None
        self._default = default
        self.endpoint = endpoint

    @property
    def default(self):
        # Fall back to the range start when no explicit default was given.
        if self._default is not None:
            return self._default
        return self.start

    def random(self, seed=None):
        """Sample a float from the range; snapped to the step grid if set."""
        rng = np.random.default_rng(seed)
        value = (self.stop - self.start) * rng.random() + self.start
        if self.step is not None:
            if self.endpoint:
                # +1e-7 nudges np.arange into including `stop` itself.
                values = np.arange(self.start, self.stop + 1e-7, step=self.step)
            else:
                values = np.arange(self.start, self.stop, step=self.step)
            closest_index = np.abs(values - value).argmin()
            value = values[closest_index]
        return value

    def get_state(self):
        state = super(FloatRange, self).get_state()
        state["start"] = self.start
        state["stop"] = self.stop
        state["step"] = self.step
        # Consistency fix (matches IntRange): store the raw default (may be
        # None) instead of the resolved one, so round-tripping preserves the
        # original "no default given" state.
        state["default"] = self._default
        state["endpoint"] = self.endpoint
        return state

    def __repr__(self):
        return "FloatRange(name: {}, start: {}, stop: {}, step: {}, default: {}, endpoint: {})".format(
            self.name, self.start, self.stop, self.step, self.default,
            self.endpoint)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.tuner import tunable_space as ts
class TestTunableSpace(unittest.TestCase):
    """Tests for TunableSpace registration, mutation and serialization."""

    def _assert_single(self, space, name):
        # Helper: the space holds exactly one variable with the given name.
        self.assertEqual(len(space.variables), 1)
        self.assertEqual(space.variables[name].name, name)

    def _check_register_and_update(self, space, name, registered, updated):
        # Helper: the variable registered with `registered`, then accepts
        # a direct value update to `updated`.
        self.assertEqual(space.values[name], registered)
        self._assert_single(space, name)
        space.values[name] = updated
        self.assertEqual(space.get_value(name), updated)
        self.assertEqual(space.values, {name: updated})
        self._assert_single(space, name)

    def test_fixed(self):
        space = ts.TunableSpace()
        space.fixed("fixed", default=4)
        self._check_register_and_update(space, "fixed", 4, 2)

    def test_boolean(self):
        space = ts.TunableSpace()
        space.boolean("boolean")
        self._check_register_and_update(space, "boolean", False, True)

    def test_choice(self):
        space = ts.TunableSpace()
        space.choice("choice", [1, 2, 3, 4], default=4)
        self._check_register_and_update(space, "choice", 4, 2)

    def test_int_range(self):
        space = ts.TunableSpace()
        space.int_range("int_range", start=1, stop=4, default=2)
        self._check_register_and_update(space, "int_range", 2, 3)

    def test_float_range(self):
        space = ts.TunableSpace()
        space.float_range("float_range", start=0.4, stop=4.4, default=2.0)
        self._check_register_and_update(space, "float_range", 2.0, 3.0)

    def test_varibles(self):
        space = ts.TunableSpace()
        space.choice("choice", [1, 2, 3, 4], default=4)
        self.assertEqual(space.values["choice"], 4)
        self._assert_single(space, "choice")
        space.int_range("int_range", start=1, stop=4, default=2)
        self.assertEqual(space.values["int_range"], 2)
        self.assertEqual(len(space.variables), 2)
        self.assertEqual(space.variables["int_range"].name, "int_range")

    def test_not_populated_variable(self):
        # A fresh variable yields its declared default.
        space = ts.TunableSpace()
        self.assertEqual(space.choice("choice", [1, 2, 3, 4], default=2), 2)

    def test_populated_variable(self):
        # A pre-populated value wins over the declared default.
        space = ts.TunableSpace()
        space.values["choice"] = 2
        self.assertEqual(space.choice("choice", [1, 2, 3, 4], default=4), 2)
        space["choice"] = 3
        self.assertNotEqual(space.values["choice"], 2)
        self.assertEqual(space.values["choice"], 3)

    def test_state(self):
        # Round-tripping through get_state/from_state preserves everything.
        space = ts.TunableSpace()
        space.choice("choice", [1, 2, 3, 4], default=4)
        space.int_range("int_range", start=1, stop=4, default=2)
        new_space = space.from_state(space.get_state())
        self.assertEqual(new_space.get_value("choice"), 4)
        self.assertEqual(new_space.get_value("int_range"), 2)
        self.assertEqual(len(new_space.variables), 2)
        self.assertEqual(len(new_space.values), 2)
        restored_choice = new_space.variables["choice"]
        self.assertEqual(restored_choice.name, "choice")
        self.assertEqual(restored_choice.default, 4)
        self.assertEqual(restored_choice.values, [1, 2, 3, 4])
        restored_range = new_space.variables["int_range"]
        self.assertEqual(restored_range.name, "int_range")
        self.assertEqual(restored_range.default, 2)
        self.assertEqual(restored_range.start, 1)
        self.assertEqual(restored_range.stop, 4)
        self.assertEqual(restored_range.step, 1)
        self.assertEqual(restored_range.endpoint, False)


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.tuner import tunable_variable as tv
class TestTunableVariable(unittest.TestCase):
    """Serialization round-trip and sampling tests for each variable type."""

    def test_fixed(self):
        # Fixed always samples its own default, for any default type.
        for value in (True, 1):
            fixed = tv.Fixed.from_state(tv.Fixed("fixed", value).get_state())
            self.assertEqual(fixed.default, value)
            self.assertEqual(fixed.random(), value)

    def test_boolean(self):
        for default in (False, True):
            boolean = tv.Boolean("bool", default)
            boolean = tv.Boolean.from_state(boolean.get_state())
            self.assertEqual(boolean.default, default)
            self.assertIn(boolean.random(), [True, False])
            self.assertIn(boolean.random(1234), [True, False])

    def test_choice(self):
        # Without an explicit default, the first choice is the default.
        choice = tv.Choice.from_state(
            tv.Choice("choice", [1, 2, 3, 4]).get_state())
        self.assertEqual(choice.default, 1)
        self.assertIn(choice.random(), [1, 2, 3, 4])
        self.assertIn(choice.random(1234), [1, 2, 3, 4])

        choice = tv.Choice.from_state(
            tv.Choice("choice", [1, 2, 3, 4], default=2).get_state())
        self.assertEqual(choice.default, 2)
        self.assertIn(choice.random(), [1, 2, 3, 4])
        self.assertIn(choice.random(1234), [1, 2, 3, 4])

    def test_int_range(self):
        int_range = tv.IntRange("int_range", start=1, stop=4, default=2)
        int_range = tv.IntRange.from_state(int_range.get_state())
        self.assertEqual(int_range.default, 2)
        self.assertIn(int_range.random(), [1, 2, 3, 4])
        self.assertIn(int_range.random(1234), [1, 2, 3, 4])
        self.assertNotEqual(int_range.default, 4)

        # Stepped range with the endpoint included.
        int_range = tv.IntRange(
            "int_range", start=1, stop=8, step=2, default=3, endpoint=True)
        int_range = tv.IntRange.from_state(int_range.get_state())
        self.assertEqual(int_range.default, 3)
        self.assertIn(int_range.random(), [1, 3, 5, 7])
        self.assertIn(int_range.random(1234), [1, 3, 5, 7])
        self.assertNotEqual(int_range.default, 2)

    def test_float_range(self):
        float_range = tv.FloatRange(
            "float_range", start=0.4, stop=4.4, default=2.0)
        float_range = tv.FloatRange.from_state(float_range.get_state())
        self.assertEqual(float_range.default, 2.0)
        self.assertGreater(float_range.random(), 0.4)
        self.assertLess(float_range.random(1234), 4.4)
        self.assertNotAlmostEqual(float_range.random(), 1)
        self.assertNotAlmostEqual(float_range.random(), 4.4)

        # Stepped range with the endpoint included.
        float_range = tv.FloatRange(
            "float_range",
            start=0.4,
            stop=8.4,
            step=2.0,
            default=3.0,
            endpoint=True)
        float_range = tv.FloatRange.from_state(float_range.get_state())
        self.assertEqual(float_range.default, 3.0)
        self.assertGreater(float_range.random(), 0.4)
        self.assertLessEqual(float_range.random(1234), 8.4)
        self.assertNotAlmostEqual(float_range.random(), 2)


if __name__ == "__main__":
    unittest.main()
......@@ -300,6 +300,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel.operators',
'paddle.distributed.auto_parallel.tuner',
'paddle.distributed.passes',
'paddle.framework',
'paddle.jit',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册