Unverified Commit dadfb099 · Author: Leo Chen · Committer: GitHub

[SemiAuto] add static branch for shard_tensor (#56561)

* shard_tensor support static graph

* add comments

* add dy2static ut

* use property in c++ side
Parent a4c8d977
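
A minimal usage sketch of what this commit enables, based on the unit tests in the diff below (the 2x4 mesh and the tensor shapes are illustrative only): in dynamic mode dist.shard_tensor returns a paddle.Tensor carrying the distributed attributes, and in static mode it now dispatches to the static-graph shard_tensor and annotates the corresponding program variable instead of raising NotImplementedError.

import paddle
import paddle.distributed as dist

# Illustrative 2 x 4 process mesh with named dims "x" and "y".
mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])
# Shard tensor dim 0 along mesh dim "x"; the other dims stay replicated.
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None, None])
print(dist_attr.sharding_specs)  # ['x', None, None], exposed via the new property

# Dynamic mode: returns a paddle.Tensor annotated with the dist_attr.
x = paddle.rand([4, 1024, 512])
d_tensor = dist.shard_tensor(x, dist_attr=dist_attr)

# Static mode: the new branch annotates the variable in the default program.
paddle.enable_static()
y = paddle.static.data(name="y", shape=[4, 1024, 512], dtype="float32")
dist.shard_tensor(y, dist_attr=dist_attr)
paddle.disable_static()
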
@@ -13,7 +13,9 @@
# limitations under the License.
import paddle
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
)
from paddle.framework import core
# These are the auto parallel APIs unified for dynamic and static mode.
@@ -44,7 +46,7 @@ class DistAttr(core.TensorDistAttr):
def __init__(self, mesh, sharding_specs):
# 1. inputs checking
if not isinstance(mesh, ProcessMesh):
if not isinstance(mesh, core.ProcessMesh):
raise ValueError(
"The mesh must be an instance of paddle.distributed.ProcessMesh."
)
@@ -55,6 +57,7 @@ class DistAttr(core.TensorDistAttr):
for dim_name in sharding_specs
), 'The dimension name in sharding_specs must be an instance of str.'
self._sharding_specs = sharding_specs
dims_mapping = [
mesh.dim_names.index(dim_name) if dim_name is not None else -1
for dim_name in sharding_specs
@@ -62,9 +65,23 @@
# 2. init core.TensorDistAttr
core.TensorDistAttr.__init__(self)
self.process_mesh = mesh
self.dims_mapping = dims_mapping
self.mark_annotated("process_mesh")
self.mark_annotated("dims_mapping")
@property
def sharding_specs(self):
"""
Get sharding_specs of the dist_attr
Returns:
list[str]: sharding_specs
"""
return self._sharding_specs
def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
@@ -121,6 +138,7 @@ def shard_tensor(
if paddle.in_dynamic_mode():
return paddle.Tensor(data, dist_attr=dist_attr)
else:
raise NotImplementedError(
"The `paddle.distributed.shard_tensor` for static mode will be implemented later."
# TODO(zhiqiu): we need to refine the static shard_tensor
return shard_tensor_static(
data, dist_attr.process_mesh, dist_attr.sharding_specs
)
@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
from paddle.framework import core
from .process_mesh import ProcessMesh, get_current_process_mesh
from .static.dist_context import get_default_distributed_context
@@ -67,7 +68,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
if process_mesh is not None:
assert isinstance(
process_mesh, ProcessMesh
process_mesh, core.ProcessMesh
), "Argument process_mesh {} is not an instance of ProcessMesh".format(
process_mesh
)
......
@@ -16,6 +16,10 @@ import unittest
import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
from paddle.fluid.dygraph.base import switch_to_static_graph
class TestDistAttrBasic(unittest.TestCase):
@@ -51,27 +55,80 @@ class TestDistAttrBasic(unittest.TestCase):
self.assertIsNotNone(exception)
class TestShardTensorBasic(unittest.TestCase):
# remove this test after static mode is supported
def test_static_mode_unimplemented(self):
exception = None
try:
paddle.enable_static()
class TestShardTensorDynamic(unittest.TestCase):
def setUp(self):
self.mesh = dist.ProcessMesh(
[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)
def test_dynamic(self):
dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=['x', None, None]
)
input = paddle.rand([4, 1024, 512])
d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
print(dist_attr.dims_mapping)
self.assertEqual(d_tensor.dist_attr.process_mesh, self.mesh)
self.assertEqual(d_tensor.dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh"))
self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping"))
class TestShardTensorStatic(unittest.TestCase):
def setUp(self):
self.mesh = dist.ProcessMesh(
[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)
@switch_to_static_graph
def test_static_mode(self):
dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=['x', None, None]
)
input = paddle.static.data(
name="input",
shape=[4, 1024, 512],
dtype='float32',
)
d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
default_dist_context = get_default_distributed_context()
dist_input = default_dist_context.get_dist_tensor_for_program(input)
self.assertEqual(dist_input.dist_attr.process_mesh, self.mesh)
self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
class TestShardTensorStaticDy2Static(unittest.TestCase):
def test_dy2static(self):
@paddle.jit.to_static
def func():
mesh = dist.ProcessMesh(
[[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
except NotImplementedError as ex:
self.assertIn(
"The `paddle.distributed.shard_tensor` for static mode will be implemented later",
str(ex),
dist_attr = dist.DistAttr(
mesh=mesh, sharding_specs=['x', None, None]
)
exception = ex
paddle.disable_static()
self.assertIsNotNone(exception)
input = paddle.rand([4, 1024, 512])
d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
return input, mesh
dy_tensor, mesh = func()
static_tensor = func.outputs[0]  # the static-program variable corresponding to `input`
default_dist_context = get_default_distributed_context()
dist_input = default_dist_context.get_dist_tensor_for_program(
static_tensor
)
self.assertEqual(dist_input.dist_attr.process_mesh, mesh)
self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
if __name__ == "__main__":
......