Unverified commit bd60757d, authored by Chen Weihang, committed by GitHub

[AutoParallel] Add shard tensor and DistAttr api (#55494)

* add shard tensor api

* add DistAttr api

* add unittest for coverage

* fix process mesh sample code

* fix checking error
Parent: 2b8e6285
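For orientation, here is a minimal end-to-end sketch of the API surface this commit adds, assembled from the docstring examples in the diff below (dynamic mode only; static mode is not yet implemented):

```python
import paddle
import paddle.distributed as dist

# Describe a 2 x 3 process topology with named mesh dimensions "x" and "y".
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])

# Shard tensor dim 0 over mesh dim "x" and tensor dim 1 over mesh dim "y".
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", "y"])

# Build a dense tensor, then attach the distributed attributes.
a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
print(d_tensor)
```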
@@ -61,8 +61,11 @@ from .communication import (
     get_backend,
 )  # noqa: F401
+from .auto_parallel.process_mesh import ProcessMesh  # noqa: F401
+from .auto_parallel.api import DistAttr  # noqa: F401
 from .auto_parallel import shard_op  # noqa: F401
-from .auto_parallel import shard_tensor  # noqa: F401
+from .auto_parallel.api import shard_tensor  # noqa: F401
 from .fleet import BoxPSDataset  # noqa: F401
@@ -120,4 +123,7 @@ __all__ = [  # noqa
     "reduce_scatter",
     "is_available",
     "get_backend",
+    "ProcessMesh",
+    "DistAttr",
+    "shard_tensor",
 ]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.framework import core
# These are the unified auto parallel APIs for both dynamic and static mode.
# Some APIs have the same names as previous API implementations; this is a
# temporary state, and the APIs here will eventually be the ones used.
class DistAttr(core.TensorDistAttr):
"""
DistAttr specifies how tensors are distributed or sliced on ProcessMesh.
Args:
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object describes the Cartesian topology of the used processes.
sharding_specs(list[str|None]): The specification describing how to shard the Tensor.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
print(dist_attr)
"""
def __init__(self, mesh, sharding_specs):
# 1. inputs checking
if not isinstance(mesh, ProcessMesh):
raise ValueError(
"The mesh must be an instance of paddle.distributed.ProcessMesh."
)
if not isinstance(sharding_specs, list):
raise ValueError("The sharding_specs must be an instance of list.")
assert all(
isinstance(dim_name, str) or dim_name is None
for dim_name in sharding_specs
), 'The dimension name in sharding_specs must be an instance of str.'
dims_mapping = [
mesh.dim_names.index(dim_name) if dim_name is not None else -1
for dim_name in sharding_specs
]
# 2. init core.TensorDistAttr
core.TensorDistAttr.__init__(self)
self.process_mesh = mesh
self.dims_mapping = dims_mapping
def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
):
"""
    Constructs a ``paddle.Tensor`` with distributed attributes from ``data``,
    which can be a scalar, tuple, list, numpy.ndarray or paddle.Tensor.
If the ``data`` is already a Tensor, transform it to a Distributed Tensor.
Args:
data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
Can be a scalar, list, tuple, numpy.ndarray, paddle.Tensor.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
            'complex64' , 'complex128'. Default: None, infers dtype from ``data``
            except for python float number which gets dtype from ``get_default_dtype`` .
place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
            CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, which means the global place. If ``place`` is a
            string, it can be ``cpu``, ``gpu:x`` or ``gpu_pinned``, where ``x`` is the index of the GPU.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
Returns:
Tensor: A Tensor constructed from ``data`` with distributed attributes.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
# dense tensor
a = paddle.to_tensor([[1,2,3],
[5,6,7]])
# distributed tensor
d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
print(d_tensor)
"""
# 1. create dense tensor
# `paddle.to_tensor` supports both dynamic and static mode
data = paddle.to_tensor(data)
# 2. create dist tensor
    assert len(dist_attr.dims_mapping) == len(
        list(data.shape)
    ), "The length of sharding_specs must be the same as the number of dimensions of the input tensor."
if paddle.in_dynamic_mode():
return paddle.Tensor(data, dist_attr=dist_attr)
else:
raise NotImplementedError(
"The `paddle.distributed.shard_tensor` for static mode will be implemented later."
)
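As a reference for the mapping logic in ``DistAttr.__init__`` above, here is a minimal sketch (not part of the diff) showing how ``sharding_specs`` entries become ``dims_mapping`` values; the expected output is an assumption based on that logic:

```python
import paddle.distributed as dist

# The 2 x 3 mesh used throughout the docstring examples above.
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])

# Shard tensor dim 0 over mesh dim "x"; replicate tensor dim 1 (None -> -1).
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])

# Each dim name is replaced by its index in mesh.dim_names, None by -1.
print(dist_attr.dims_mapping)  # expected: [0, -1]
```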
@@ -82,8 +82,9 @@ class ProcessMesh(core.ProcessMesh):
         .. code-block:: python

             import paddle
+            import paddle.distributed as dist

-            mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
+            mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
             assert mesh.shape == [2, 3]
             assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
@@ -162,6 +163,13 @@ class ProcessMesh(core.ProcessMesh):
         """
         return self._mesh

+    @property
+    def dim_names(self):
+        """
+        Get the underlying dimension names of ProcessMesh.
+        """
+        return self._dim_names
+
     @property
     def unique_id(self):
         """
......
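A quick, hypothetical check of the new ``dim_names`` property, assuming the mesh from the docstring example above:

```python
import paddle.distributed as dist

mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
# The new property exposes the mesh's dimension names.
print(mesh.dim_names)  # expected: ["x", "y"]
```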
@@ -153,6 +153,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_engine_save_load MODULES test_engine_save_load)
   py_test_modules(test_rule_based_tuner MODULES test_rule_based_tuner)
   py_test_modules(test_dist_tensor MODULES test_dist_tensor)
+  py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api)
   # End of unittests WITH single card WITHOUT timeout
 endif()
@@ -17,20 +17,22 @@ import unittest
 import numpy as np

 import paddle
+import paddle.distributed as dist


 class TestDistTensor(unittest.TestCase):
     def test_dist_tensor_creation(self):
         shape = [10, 5]
-        dist_attr = paddle.fluid.core.TensorDistAttr()
+        mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
+        dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])

         # create dist tensor using numpy
-        dist_tensor_with_numpy = paddle.Tensor(
+        dist_tensor_with_numpy = dist.shard_tensor(
             np.ones(shape, dtype=np.float32), dist_attr=dist_attr
         )

         # create dist tensor using tensor
-        dist_tensor_with_tensor = paddle.Tensor(
+        dist_tensor_with_tensor = dist.shard_tensor(
             paddle.ones(shape), dist_attr=dist_attr
         )
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.distributed as dist
class TestDistAttrBasic(unittest.TestCase):
def test_mesh_argument_error(self):
exception = None
try:
mesh = [[0, 1], [2, 3]]
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
except ValueError as ex:
self.assertIn(
"The mesh must be an instance of paddle.distributed.ProcessMesh",
str(ex),
)
exception = ex
self.assertIsNotNone(exception)
def test_sharding_specs_argument_error(self):
exception = None
try:
mesh = dist.ProcessMesh(
[[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
)
dist_attr = dist.DistAttr(
mesh=mesh, sharding_specs={"x": 0, "y": 1}
)
except ValueError as ex:
self.assertIn(
"The sharding_specs must be an instance of list", str(ex)
)
exception = ex
self.assertIsNotNone(exception)
class TestShardTensorBasic(unittest.TestCase):
# remove this test after static mode is supported
def test_static_mode_unimplemented(self):
exception = None
try:
paddle.enable_static()
mesh = dist.ProcessMesh(
[[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
)
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
except NotImplementedError as ex:
self.assertIn(
"The `paddle.distributed.shard_tensor` for static mode will be implemented later",
str(ex),
)
exception = ex
paddle.disable_static()
self.assertIsNotNone(exception)
if __name__ == "__main__":
unittest.main()