diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py
index 348edaef68198ce511e83b55956186a4ecad0349..30055c5b763a140faa0bb534579e2480e74651f6 100644
--- a/python/paddle/distributed/auto_parallel/interface.py
+++ b/python/paddle/distributed/auto_parallel/interface.py
@@ -271,12 +271,22 @@ class ProcessMesh(object):
 
 
 def _dim_mapping_checker(tensor, mesh, dim_mapping):
-    assert len(tensor.shape) == len(dim_mapping)
+    assert isinstance(mesh,
+                      ProcessMesh), 'The type of mesh must be ProcessMesh.'
+    assert isinstance(dim_mapping,
+                      list), 'The type of dim_mapping must be list.'
+    assert len(tensor.shape) == len(dim_mapping), (
+        'The number of dimensions '
+        'of tensor must be the same as the length of its corresponding '
+        'dim_mapping.')
     mesh_dim = len(mesh.topology)
     dim_set = set()
     for i in range(len(dim_mapping)):
-        assert dim_mapping[i] == -1 or (dim_mapping[i] < mesh_dim and
-                                        dim_mapping[i] >= 0)
+        assert dim_mapping[i] == -1 or (
+            dim_mapping[i] < mesh_dim and dim_mapping[i] >= 0), (
+                'Each element in dim_mapping must be greater than or equal '
+                'to zero and less than the length of its corresponding '
+                'topology, or it must be -1.')
         if dim_mapping[i] >= 0:
             assert dim_mapping[i] not in dim_set
             dim_set.add(dim_mapping[i])
@@ -347,6 +357,7 @@ def set_shard_mask(x, mask):
             mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
             mask = [[1, 0, 1], [0, 1, 0]]
             x = paddle.ones([4, 6])
+            dist.shard_tensor(x, mesh, [-1, 1])
             dist.set_shard_mask(x, mask)
 
     """
@@ -355,6 +366,9 @@ def set_shard_mask(x, mask):
     np_mask = numpy.array(mask)
     min_ele = numpy.min(np_mask)
    max_ele = numpy.max(np_mask)
+    mesh_attr_name = _append_attr_suffix('mesh_id')
+    assert x._has_attr(mesh_attr_name), \
+        "Please set process mesh for the variable firstly."
     assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
     x_mesh = x.process_mesh
     assert x_mesh, "Please set process mesh for the variable firstly."
@@ -403,7 +417,15 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
     op_size = len(main_block.ops)
     output = op_fn(**kwargs)
     new_op_size = len(main_block.ops)
-    if dim_mapping_dict is None: dim_mapping_dict = dict()
+    if dim_mapping_dict is None:
+        dim_mapping_dict = dict()
+    else:
+        assert isinstance(dim_mapping_dict,
+                          dict), 'The type of dim_mapping_dict must be dict.'
+        for var_name in dim_mapping_dict.keys():
+            dim_mapping = dim_mapping_dict[var_name]
+            tensor = main_block.var(var_name)
+            _dim_mapping_checker(tensor, mesh, dim_mapping)
     for idx in range(op_size, new_op_size):
         op = main_block.ops[idx]
         attr_name = _append_attr_suffix('mesh_id')
@@ -477,4 +499,5 @@ def set_pipeline_stage(stage):
     """
     from paddle.fluid.framework import _set_pipeline_stage
     _static_mode_check()
+    assert isinstance(stage, int), 'The type of stage must be int.'
     _set_pipeline_stage(stage)
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py
index 15981d461aed95b3305e15a27081b79f87839be7..3f1d692b72e984fe416968af3416ea0f2a210747 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py
@@ -97,8 +97,8 @@ class TestAutoParallelAPI(unittest.TestCase):
         self.assertEqual(last_op.pipeline_stage, LAST_PP_STAGE)
 
-        DIMS_MAPPING1 = [0, 1, -1]
-        DIMS_MAPPING2 = [-1, 2, 0]
+        DIMS_MAPPING1 = [0, 1]
+        DIMS_MAPPING2 = [-1, 0]
         kwargs = {'x': data2, 'y': data3}
         dist.shard_op(
             paddle.add,
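
A minimal usage sketch of the ordering requirement introduced by this patch (not part of the diff; it mirrors the example in the set_shard_mask docstring and assumes the usual `import paddle` / `import paddle.distributed as dist` plus static mode): set_shard_mask now asserts that a process mesh has already been attached to the tensor, so dist.shard_tensor must be called before dist.set_shard_mask.

    # Sketch only: illustrates the new check added to set_shard_mask.
    import paddle
    import paddle.distributed as dist

    paddle.enable_static()

    mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
    mask = [[1, 0, 1], [0, 1, 0]]
    x = paddle.ones([4, 6])

    # Calling dist.set_shard_mask(x, mask) at this point would now fail fast
    # with "Please set process mesh for the variable firstly."
    dist.shard_tensor(x, mesh, [-1, 1])   # attach the process mesh first
    dist.set_shard_mask(x, mask)          # passes the new _has_attr check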