未验证 提交 39540b0e 编写于 作者: L lilong12 提交者: GitHub

add checkers for auto parallel apis (#35486)

* update, test=develop
上级 c4a3e8b4
...@@ -271,12 +271,22 @@ class ProcessMesh(object): ...@@ -271,12 +271,22 @@ class ProcessMesh(object):
def _dim_mapping_checker(tensor, mesh, dim_mapping): def _dim_mapping_checker(tensor, mesh, dim_mapping):
assert len(tensor.shape) == len(dim_mapping) assert isinstance(mesh,
ProcessMesh), 'The type of mesh must be ProcessMesh.'
assert isinstance(dim_mapping,
list), 'The type of dim_mapping must be list.'
assert len(tensor.shape) == len(dim_mapping), (
'The number of dimensions '
'of tensor must be the same as the length of its corresponding '
'dim_mapping.')
mesh_dim = len(mesh.topology) mesh_dim = len(mesh.topology)
dim_set = set() dim_set = set()
for i in range(len(dim_mapping)): for i in range(len(dim_mapping)):
assert dim_mapping[i] == -1 or (dim_mapping[i] < mesh_dim and assert dim_mapping[i] == -1 or (
dim_mapping[i] >= 0) dim_mapping[i] < mesh_dim and dim_mapping[i] >= 0), (
'Each element '
'in dim_mapping must be greater than zero and less than the '
'length of its corresponding topology, or it must be -1.')
if dim_mapping[i] >= 0: if dim_mapping[i] >= 0:
assert dim_mapping[i] not in dim_set assert dim_mapping[i] not in dim_set
dim_set.add(dim_mapping[i]) dim_set.add(dim_mapping[i])
...@@ -347,6 +357,7 @@ def set_shard_mask(x, mask): ...@@ -347,6 +357,7 @@ def set_shard_mask(x, mask):
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mask = [[1, 0, 1], [0, 1, 0]] mask = [[1, 0, 1], [0, 1, 0]]
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [-1, 1])
dist.set_shard_mask(x, mask) dist.set_shard_mask(x, mask)
""" """
...@@ -355,6 +366,9 @@ def set_shard_mask(x, mask): ...@@ -355,6 +366,9 @@ def set_shard_mask(x, mask):
np_mask = numpy.array(mask) np_mask = numpy.array(mask)
min_ele = numpy.min(np_mask) min_ele = numpy.min(np_mask)
max_ele = numpy.max(np_mask) max_ele = numpy.max(np_mask)
mesh_attr_name = _append_attr_suffix('mesh_id')
assert x._has_attr(mesh_attr_name), \
"Please set process mesh for the variable firstly."
assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1." assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
x_mesh = x.process_mesh x_mesh = x.process_mesh
assert x_mesh, "Please set process mesh for the variable firstly." assert x_mesh, "Please set process mesh for the variable firstly."
...@@ -403,7 +417,15 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs): ...@@ -403,7 +417,15 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
op_size = len(main_block.ops) op_size = len(main_block.ops)
output = op_fn(**kwargs) output = op_fn(**kwargs)
new_op_size = len(main_block.ops) new_op_size = len(main_block.ops)
if dim_mapping_dict is None: dim_mapping_dict = dict() if dim_mapping_dict is None:
dim_mapping_dict = dict()
else:
assert isinstance(dim_mapping_dict,
dict), 'The type of dim_mapping_dict must be dict.'
for var_name in dim_mapping_dict.keys():
dim_mapping = dim_mapping_dict[var_name]
tensor = main_block.var(var_name)
_dim_mapping_checker(tensor, mesh, dim_mapping)
for idx in range(op_size, new_op_size): for idx in range(op_size, new_op_size):
op = main_block.ops[idx] op = main_block.ops[idx]
attr_name = _append_attr_suffix('mesh_id') attr_name = _append_attr_suffix('mesh_id')
...@@ -477,4 +499,5 @@ def set_pipeline_stage(stage): ...@@ -477,4 +499,5 @@ def set_pipeline_stage(stage):
""" """
from paddle.fluid.framework import _set_pipeline_stage from paddle.fluid.framework import _set_pipeline_stage
_static_mode_check() _static_mode_check()
assert isinstance(stage, int), 'The type of stage must be int.'
_set_pipeline_stage(stage) _set_pipeline_stage(stage)
...@@ -97,8 +97,8 @@ class TestAutoParallelAPI(unittest.TestCase): ...@@ -97,8 +97,8 @@ class TestAutoParallelAPI(unittest.TestCase):
self.assertEqual(last_op.pipeline_stage, LAST_PP_STAGE) self.assertEqual(last_op.pipeline_stage, LAST_PP_STAGE)
DIMS_MAPPING1 = [0, 1, -1] DIMS_MAPPING1 = [0, 1]
DIMS_MAPPING2 = [-1, 2, 0] DIMS_MAPPING2 = [-1, 0]
kwargs = {'x': data2, 'y': data3} kwargs = {'x': data2, 'y': data3}
dist.shard_op( dist.shard_op(
paddle.add, paddle.add,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册