From bd60757d03f062583b95c966e84036985c830f3b Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Mon, 24 Jul 2023 10:35:46 +0800
Subject: [PATCH] [AutoParallel] Add shard tensor and DistAttr api (#55494)

* add shard tensor api

* add DistAttr api

* add unittest for coverage

* fix process mesh sample code

* fix checking error
---
 python/paddle/distributed/__init__.py         |   8 +-
 .../paddle/distributed/auto_parallel/api.py   | 126 ++++++++++++++++++
 .../distributed/auto_parallel/process_mesh.py |  10 +-
 test/auto_parallel/CMakeLists.txt             |   1 +
 test/auto_parallel/test_dist_tensor.py        |   8 +-
 test/auto_parallel/test_shard_tensor_api.py   |  78 +++++++++++
 6 files changed, 226 insertions(+), 5 deletions(-)
 create mode 100644 python/paddle/distributed/auto_parallel/api.py
 create mode 100644 test/auto_parallel/test_shard_tensor_api.py

diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py
index 8f6237bfa4c..183f307607c 100644
--- a/python/paddle/distributed/__init__.py
+++ b/python/paddle/distributed/__init__.py
@@ -61,8 +61,11 @@ from .communication import (
     get_backend,
 )  # noqa: F401
 
+from .auto_parallel.process_mesh import ProcessMesh  # noqa: F401
+from .auto_parallel.api import DistAttr  # noqa: F401
+
 from .auto_parallel import shard_op  # noqa: F401
-from .auto_parallel import shard_tensor  # noqa: F401
+from .auto_parallel.api import shard_tensor  # noqa: F401
 
 from .fleet import BoxPSDataset  # noqa: F401
 
@@ -120,4 +123,7 @@ __all__ = [  # noqa
     "reduce_scatter",
     "is_available",
     "get_backend",
+    "ProcessMesh",
+    "DistAttr",
+    "shard_tensor",
 ]
diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
new file mode 100644
index 00000000000..b25799d058a
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.framework import core
+
+# These are the auto parallel APIs of the unified version of dynamic and static mode.
+# Some APIs have the same names as the previous implementations; this is a
+# temporary state, and the APIs here will eventually replace them.
+
+
+class DistAttr(core.TensorDistAttr):
+    """
+    DistAttr specifies how tensors are distributed or sliced on ProcessMesh.
+
+    Args:
+        mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object that describes the Cartesian topology of the processes used.
+        sharding_specs(list[str|None]): The specification describing how to shard the Tensor.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            import paddle.distributed as dist
+
+            mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
+            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+
+            print(dist_attr)
+    """
+
+    def __init__(self, mesh, sharding_specs):
+        # 1. input checking
+        if not isinstance(mesh, ProcessMesh):
+            raise ValueError(
+                "The mesh must be an instance of paddle.distributed.ProcessMesh."
+            )
+        if not isinstance(sharding_specs, list):
+            raise ValueError("The sharding_specs must be an instance of list.")
+        assert all(
+            isinstance(dim_name, str) or dim_name is None
+            for dim_name in sharding_specs
+        ), 'The dimension name in sharding_specs must be an instance of str or None.'
+
+        dims_mapping = [
+            mesh.dim_names.index(dim_name) if dim_name is not None else -1
+            for dim_name in sharding_specs
+        ]
+
+        # 2. init core.TensorDistAttr
+        core.TensorDistAttr.__init__(self)
+        self.process_mesh = mesh
+        self.dims_mapping = dims_mapping
+
+
+def shard_tensor(
+    data, dtype=None, place=None, stop_gradient=True, dist_attr=None
+):
+    """
+    Constructs a ``paddle.Tensor`` with distributed attributes from ``data``,
+    which can be a scalar, tuple, list, numpy.ndarray or paddle.Tensor.
+
+    If the ``data`` is already a Tensor, it will be transformed into a distributed Tensor.
+
+    Args:
+        data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
+            Can be a scalar, list, tuple, numpy.ndarray, paddle.Tensor.
+        dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
+            'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
+            'complex64' , 'complex128'. Default: None, infers dtype from ``data``
+            except for python float numbers, which get their dtype from ``get_default_dtype`` .
+        place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
+            CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, which means the global place. If ``place`` is
+            a string, it can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPU.
+        stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
+        dist_attr(paddle.distributed.DistAttr): Specifies how tensors are distributed or sliced on ProcessMesh.
+
+    Returns:
+        Tensor: A Tensor constructed from ``data`` with distributed attributes.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            import paddle.distributed as dist
+
+            mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
+            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+
+            # dense tensor
+            a = paddle.to_tensor([[1,2,3],
+                                  [5,6,7]])
+            # distributed tensor
+            d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+            print(d_tensor)
+    """
+    # 1. create dense tensor
+    # `paddle.to_tensor` supports both dynamic and static mode
+    data = paddle.to_tensor(data)
+
+    # 2. create dist tensor
+    assert len(dist_attr.dims_mapping) == len(
+        list(data.shape)
+    ), "The length of sharding_specs must be the same as the number of dimensions of the input tensor."
+
+    if paddle.in_dynamic_mode():
+        return paddle.Tensor(data, dist_attr=dist_attr)
+    else:
+        raise NotImplementedError(
+            "The `paddle.distributed.shard_tensor` for static mode will be implemented later."
+        )
diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py
index 1c2f292e5f8..a6ad3355d7d 100644
--- a/python/paddle/distributed/auto_parallel/process_mesh.py
+++ b/python/paddle/distributed/auto_parallel/process_mesh.py
@@ -82,8 +82,9 @@ class ProcessMesh(core.ProcessMesh):
         .. code-block:: python
 
             import paddle
+            import paddle.distributed as dist
 
-            mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
+            mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
             assert mesh.shape == [2, 3]
             assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
 
@@ -162,6 +163,13 @@ class ProcessMesh(core.ProcessMesh):
         """
         return self._mesh
 
+    @property
+    def dim_names(self):
+        """
+        Get the underlying dimension names of ProcessMesh.
+        """
+        return self._dim_names
+
     @property
     def unique_id(self):
         """
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index fe67e34c03d..83c2ae84182 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -153,6 +153,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_engine_save_load MODULES test_engine_save_load)
   py_test_modules(test_rule_based_tuner MODULES test_rule_based_tuner)
   py_test_modules(test_dist_tensor MODULES test_dist_tensor)
+  py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api)
 
   # End of unittests WITH single card WITHOUT timeout
 endif()
diff --git a/test/auto_parallel/test_dist_tensor.py b/test/auto_parallel/test_dist_tensor.py
index 58ebc085004..61705a322e2 100644
--- a/test/auto_parallel/test_dist_tensor.py
+++ b/test/auto_parallel/test_dist_tensor.py
@@ -17,20 +17,22 @@ import unittest
 import numpy as np
 
 import paddle
+import paddle.distributed as dist
 
 
 class TestDistTensor(unittest.TestCase):
     def test_dist_tensor_creation(self):
         shape = [10, 5]
-        dist_attr = paddle.fluid.core.TensorDistAttr()
+        mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
+        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
 
         # create dist tensor using numpy
-        dist_tensor_with_numpy = paddle.Tensor(
+        dist_tensor_with_numpy = dist.shard_tensor(
             np.ones(shape, dtype=np.float32), dist_attr=dist_attr
         )
 
         # create dist tensor using tensor
-        dist_tensor_with_tensor = paddle.Tensor(
+        dist_tensor_with_tensor = dist.shard_tensor(
             paddle.ones(shape), dist_attr=dist_attr
         )
 
diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py
new file mode 100644
index 00000000000..124c7dc7ba3
--- /dev/null
+++ b/test/auto_parallel/test_shard_tensor_api.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+import paddle.distributed as dist
+
+
+class TestDistAttrBasic(unittest.TestCase):
+    def test_mesh_argument_error(self):
+        exception = None
+        try:
+            mesh = [[0, 1], [2, 3]]
+            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+        except ValueError as ex:
+            self.assertIn(
+                "The mesh must be an instance of paddle.distributed.ProcessMesh",
+                str(ex),
+            )
+            exception = ex
+
+        self.assertIsNotNone(exception)
+
+    def test_sharding_specs_argument_error(self):
+        exception = None
+        try:
+            mesh = dist.ProcessMesh(
+                [[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
+            )
+            dist_attr = dist.DistAttr(
+                mesh=mesh, sharding_specs={"x": 0, "y": 1}
+            )
+        except ValueError as ex:
+            self.assertIn(
+                "The sharding_specs must be an instance of list", str(ex)
+            )
+            exception = ex
+
+        self.assertIsNotNone(exception)
+
+
+class TestShardTensorBasic(unittest.TestCase):
+    # remove this test after static mode is supported
+    def test_static_mode_unimplemented(self):
+        exception = None
+        try:
+            paddle.enable_static()
+            mesh = dist.ProcessMesh(
+                [[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
+            )
+            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
+            a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
+            d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+        except NotImplementedError as ex:
+            self.assertIn(
+                "The `paddle.distributed.shard_tensor` for static mode will be implemented later",
+                str(ex),
+            )
+            exception = ex
+        paddle.disable_static()
+
+        self.assertIsNotNone(exception)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
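
For reference, the two APIs added by this patch are intended to be used together as follows. This is a minimal sketch assembled from the docstring examples above; the process IDs in the mesh are illustrative and assume a multi-process launch that actually provides those ranks.

    import paddle
    import paddle.distributed as dist

    # Describe a 2 x 3 Cartesian topology over six processes,
    # with the two mesh axes named "x" and "y".
    mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])

    # Shard dim 0 of a tensor along mesh axis "x" and dim 1 along "y";
    # the length of sharding_specs must match the tensor's rank.
    dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", "y"])

    # Build a dense tensor, then wrap it as a distributed tensor
    # (dynamic mode only; static mode raises NotImplementedError for now).
    a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
    d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
    print(d_tensor)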