Unverified commit dadfb099, authored by Leo Chen, committed by GitHub

[SemiAuto] add static branch for shard_tensor (#56561)

* shard_tensor support static graph

* add comments

* add dy2static ut

* use property in c++ side
Parent a4c8d977
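For context, a minimal sketch of what this change enables, mirroring the new unit tests in the diff below; the 2x4 mesh, shapes, and variable names are illustrative, not part of the commit:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None, None])

# Dynamic mode (unchanged): shard_tensor returns a distributed Tensor.
dense = paddle.rand([4, 1024, 512])
d_tensor = dist.shard_tensor(dense, dist_attr=dist_attr)

# Static mode (new in this commit): instead of raising NotImplementedError,
# shard_tensor dispatches to the existing static-graph implementation and
# annotates the variable in the default distributed context.
paddle.enable_static()
x = paddle.static.data(name="x", shape=[4, 1024, 512], dtype="float32")
dist.shard_tensor(x, dist_attr=dist_attr)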
@@ -13,7 +13,9 @@
 # limitations under the License.

 import paddle
-from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.interface import (
+    shard_tensor as shard_tensor_static,
+)
 from paddle.framework import core

 # There are the auto parallel API of the unified version of dynamic and static mode.
@@ -44,7 +46,7 @@ class DistAttr(core.TensorDistAttr):

     def __init__(self, mesh, sharding_specs):
         # 1. inputs checking
-        if not isinstance(mesh, ProcessMesh):
+        if not isinstance(mesh, core.ProcessMesh):
             raise ValueError(
                 "The mesh must be an instance of paddle.distributed.ProcessMesh."
             )
@@ -55,6 +57,7 @@ class DistAttr(core.TensorDistAttr):
             for dim_name in sharding_specs
         ), 'The dimension name in sharding_specs must be an instance of str.'
+        self._sharding_specs = sharding_specs
         dims_mapping = [
             mesh.dim_names.index(dim_name) if dim_name is not None else -1
             for dim_name in sharding_specs
@@ -62,9 +65,23 @@ class DistAttr(core.TensorDistAttr):
         # 2. init core.TensorDistAttr
         core.TensorDistAttr.__init__(self)
         self.process_mesh = mesh
         self.dims_mapping = dims_mapping
+        self.mark_annotated("process_mesh")
+        self.mark_annotated("dims_mapping")
+
+    @property
+    def sharding_specs(self):
+        """
+        Get sharding_specs of the dist_attr
+
+        Returns:
+            list[str]: sharding_specs
+        """
+        return self._sharding_specs


 def shard_tensor(
     data, dtype=None, place=None, stop_gradient=True, dist_attr=None
@@ -121,6 +138,7 @@ def shard_tensor(
     if paddle.in_dynamic_mode():
         return paddle.Tensor(data, dist_attr=dist_attr)
     else:
-        raise NotImplementedError(
-            "The `paddle.distributed.shard_tensor` for static mode will be implemented later."
-        )
+        # TODO(zhiqiu): we need to refine the static shard_tensor
+        return shard_tensor_static(
+            data, dist_attr.process_mesh, dist_attr.sharding_specs
+        )
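To summarize the DistAttr changes above, a small sketch (the mesh and specs are illustrative): the dim names in sharding_specs are resolved against mesh.dim_names to build dims_mapping, the raw specs are kept on the new read-only sharding_specs property, and both fields are marked as user-annotated in __init__.

import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])
attr = dist.DistAttr(mesh=mesh, sharding_specs=["y", None, "x"])

print(attr.dims_mapping)    # [1, -1, 0]: dim names resolved via mesh.dim_names.index
print(attr.sharding_specs)  # ["y", None, "x"]: kept verbatim by the new property
print(attr.is_annotated("process_mesh"))  # True, set by mark_annotated in __init__
print(attr.is_annotated("dims_mapping"))  # True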
@@ -13,6 +13,7 @@
 # limitations under the License.

 import paddle
+from paddle.framework import core

 from .process_mesh import ProcessMesh, get_current_process_mesh
 from .static.dist_context import get_default_distributed_context
@@ -67,7 +68,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
     if process_mesh is not None:
         assert isinstance(
-            process_mesh, ProcessMesh
+            process_mesh, core.ProcessMesh
         ), "Argument process_mesh {} is not an instance of ProcessMesh".format(
            process_mesh
         )
......
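The relaxed isinstance check above matters because the new static branch forwards dist_attr.process_mesh, which is exposed as a core.ProcessMesh rather than the Python ProcessMesh wrapper (an inference from the diff, not stated in the commit message). The sketch below, mirroring the new TestShardTensorStatic test, shows how a static-mode annotation is read back; the variable name and shape are illustrative:

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.dist_context import (
    get_default_distributed_context,
)

paddle.enable_static()
mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None, None])

x = paddle.static.data(name="x", shape=[4, 1024, 512], dtype="float32")
dist.shard_tensor(x, dist_attr=dist_attr)

# In static mode the sharding is recorded on the program rather than returned
# as a distributed tensor, so it is queried from the default distributed context.
ctx = get_default_distributed_context()
dist_x = ctx.get_dist_tensor_for_program(x)
assert dist_x.dist_attr.dims_mapping == [0, -1, -1]
assert dist_x.dist_attr.is_annotated("process_mesh")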
@@ -16,6 +16,10 @@ import unittest

 import paddle
 import paddle.distributed as dist
+from paddle.distributed.auto_parallel.static.dist_context import (
+    get_default_distributed_context,
+)
+from paddle.fluid.dygraph.base import switch_to_static_graph


 class TestDistAttrBasic(unittest.TestCase):
@@ -51,27 +55,80 @@ class TestDistAttrBasic(unittest.TestCase):
         self.assertIsNotNone(exception)


-class TestShardTensorBasic(unittest.TestCase):
-    # remove this test after static mode is supported
-    def test_static_mode_unimplemented(self):
-        exception = None
-        try:
-            paddle.enable_static()
-            mesh = dist.ProcessMesh(
-                [[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
-            )
-            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
-            a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
-            d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
-        except NotImplementedError as ex:
-            self.assertIn(
-                "The `paddle.distributed.shard_tensor` for static mode will be implemented later",
-                str(ex),
-            )
-            exception = ex
-            paddle.disable_static()
-
-        self.assertIsNotNone(exception)
+class TestShardTensorDynamic(unittest.TestCase):
+    def setUp(self):
+        self.mesh = dist.ProcessMesh(
+            [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+        )
+
+    def test_dynamic(self):
+        dist_attr = dist.DistAttr(
+            mesh=self.mesh, sharding_specs=['x', None, None]
+        )
+
+        input = paddle.rand([4, 1024, 512])
+        d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
+        print(dist_attr.dims_mapping)
+
+        self.assertEqual(d_tensor.dist_attr.process_mesh, self.mesh)
+        self.assertEqual(d_tensor.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping"))
+
+
+class TestShardTensorStatic(unittest.TestCase):
+    def setUp(self):
+        self.mesh = dist.ProcessMesh(
+            [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+        )
+
+    @switch_to_static_graph
+    def test_static_mode(self):
+        dist_attr = dist.DistAttr(
+            mesh=self.mesh, sharding_specs=['x', None, None]
+        )
+
+        input = paddle.static.data(
+            name="input",
+            shape=[4, 1024, 512],
+            dtype='float32',
+        )
+        d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
+
+        default_dist_context = get_default_distributed_context()
+        dist_input = default_dist_context.get_dist_tensor_for_program(input)
+        self.assertEqual(dist_input.dist_attr.process_mesh, self.mesh)
+        self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
+
+
+class TestShardTensorStaticDy2Static(unittest.TestCase):
+    def test_dy2static(self):
+        @paddle.jit.to_static
+        def func():
+            mesh = dist.ProcessMesh(
+                [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+            )
+            dist_attr = dist.DistAttr(
+                mesh=mesh, sharding_specs=['x', None, None]
+            )
+
+            input = paddle.rand([4, 1024, 512])
+            d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
+            return input, mesh
+
+        dy_tensor, mesh = func()
+        static_tensor = func.outputs[0]  # get the inputs of static program
+
+        default_dist_context = get_default_distributed_context()
+        dist_input = default_dist_context.get_dist_tensor_for_program(
+            static_tensor
+        )
+        self.assertEqual(dist_input.dist_attr.process_mesh, mesh)
+        self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))


 if __name__ == "__main__":
......