From fc5fb2a163a5dbdfb9e643875f0fcbe0a6f31348 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Wed, 15 Sep 2021 17:30:31 +0800 Subject: [PATCH] add dist_attr for dist op and var (#35585) * add dist_attr for dist op * add unitest * update inputname * update function name * add unitest * update CMakeLists.txt for CI * fix dis_matmul * fix compile error * update matmul to matmul_v2 --- paddle/fluid/operators/searchsorted_op.h | 32 ++- .../auto_parallel/operators/common.py | 47 ++++ .../auto_parallel/operators/dist_embedding.py | 15 +- .../auto_parallel/operators/dist_matmul.py | 74 ++++-- .../fluid/tests/unittests/CMakeLists.txt | 2 + .../test_auto_parallel_partitioner.py | 242 +++++++++++++++--- 6 files changed, 354 insertions(+), 58 deletions(-) diff --git a/paddle/fluid/operators/searchsorted_op.h b/paddle/fluid/operators/searchsorted_op.h index 5ae0e79907..9eec755293 100644 --- a/paddle/fluid/operators/searchsorted_op.h +++ b/paddle/fluid/operators/searchsorted_op.h @@ -30,13 +30,37 @@ using Tensor = framework::Tensor; template class GpuAndCpuSearchSortedCompute { public: - static HOSTDEVICE bool IsNan(float x) { return ::isnan(x); } - static HOSTDEVICE bool IsNan(double x) { return ::isnan(x); } + static HOSTDEVICE bool IsNan(float x) { +#ifdef __NVCC__ + return ::isnan(x); +#else + return std::isnan(x); +#endif + } + static HOSTDEVICE bool IsNan(double x) { +#ifdef __NVCC__ + return ::isnan(x); +#else + return std::isnan(x); +#endif + } static HOSTDEVICE bool IsNan(int x) { return false; } static HOSTDEVICE bool IsNan(int64_t x) { return false; } - static HOSTDEVICE bool IsInf(float x) { return ::isinf(x); } - static HOSTDEVICE bool IsInf(double x) { return ::isinf(x); } + static HOSTDEVICE bool IsInf(float x) { +#ifdef __NVCC__ + return ::isinf(x); +#else + return std::isinf(x); +#endif + } + static HOSTDEVICE bool IsInf(double x) { +#ifdef __NVCC__ + return ::isinf(x); +#else + return std::isinf(x); +#endif + } static HOSTDEVICE bool IsInf(int x) { return false; } static HOSTDEVICE bool IsInf(int64_t x) { return false; } diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index ef2f508344..1b0b05d395 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -114,3 +114,50 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr, best_compatible_impl, idx = None, -1 return best_compatible_impl, idx + + +def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var): + """ + copy src var's dist_attr to dst var + """ + import copy + + auto_paralle_context = src_op_dist_attr.get_owner_context() + dist_attr = copy.deepcopy( + auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + dist_attr._owner_tensor = var + dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( + src_var)._owner_context + auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) + + +def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr): + """ + copy src op's dist_attr to dst dist op + """ + from ..attribute import OperatorDistributedAttribute + + auto_paralle_context = src_op_dist_attr.get_owner_context() + op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context) + auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc, + op_dist_attr) + 
auto_paralle_context.set_op_distributed_attr_for_program(dist_op, + op_dist_attr) + + op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh()) + op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx()) + + for input_varname in dist_op.desc.input_arg_names(): + input_var = dst_block.var(input_varname) + tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( + input_var) + tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping) + + for output_varname in dist_op.desc.output_arg_names(): + output_var = dst_block.var(output_varname) + tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( + output_var) + tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + op_dist_attr.set_output_dims_mapping(output_varname, + tensor_dims_mapping) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 5d1cfcbf69..141c3d14a7 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -16,6 +16,8 @@ from .common import DistributedOperator from .common import DistributedOperatorImpl from .common import register_distributed_operator from .common import register_distributed_operator_impl +from .common import copy_distributed_attr_for_var +from .common import copy_distributed_attr_for_dist_op from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -173,13 +175,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=Out_var.stop_gradient) + # copy Out_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, + Out_var) check_variable_and_dtype( Out_var, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], 'c_allreduce_sum') - dst_block.append_op( + c_embedding_op = dst_block.append_op( type='c_embedding', inputs={'Ids': [Ids_var], 'W': [Weight_var]}, @@ -187,7 +192,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): attrs={"start_index": relative_idx}) # use_model_parallel - dst_block.append_op( + c_allreduce_sum_op = dst_block.append_op( type='c_allreduce_sum', inputs={'X': [intermediate_var_0]}, outputs={'Out': [Out_var]}, @@ -197,6 +202,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'use_model_parallel': True, }) + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(c_embedding_op, dst_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block, + op_dist_attr) + if in_dygraph_mode(): raise NotImplementedError( "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 91bad5bc34..10a01dc57e 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -16,6 +16,8 @@ from .common import DistributedOperator from .common import DistributedOperatorImpl from .common import register_distributed_operator from .common import register_distributed_operator_impl +from .common import copy_distributed_attr_for_var +from .common import 
copy_distributed_attr_for_dist_op from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -223,13 +225,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=X_var.stop_gradient) + # copy X_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, + X_var) check_variable_and_dtype( X_var, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') - dst_block.append_op( + c_identity_op = dst_block.append_op( type='c_identity', inputs={'X': [X_var]}, outputs={'Out': intermediate_var_0}, @@ -250,12 +255,18 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): 'alpha': 1, } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} - dst_block.append_op( + matmul_op = dst_block.append_op( type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs) + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(c_identity_op, dst_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(matmul_op, dst_block, + op_dist_attr) + if in_dygraph_mode(): raise NotImplementedError( "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( @@ -369,13 +380,17 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): persistable=False, is_data=False, need_check_feed=Out_var.desc.need_check_feed()) - dst_block.append_op( + # copy Out_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, + Out_var) + + matmul_op = dst_block.append_op( type='matmul', inputs=inputs, outputs={'Out': intermediate_var_0}, attrs=attrs) - dst_block.append_op( + c_allreduce_sum_op = dst_block.append_op( type='c_allreduce_sum', inputs={'X': intermediate_var_0}, outputs={'Out': Out_var}, @@ -385,6 +400,12 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): 'use_model_parallel': True }) + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(matmul_op, dst_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block, + op_dist_attr) + if in_dygraph_mode(): raise NotImplementedError( "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( @@ -540,15 +561,12 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): Out_var = dst_block.var(output_name_mapping['Out'][0]) # TODO infer logic comm presentation - from ..process import new_process_group - from ..transpiler import _get_comm_group model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( )._get_model_parallel_info() - group_ranks = _get_comm_group(process_mesh.topology, - model_parallel_axis, - process_mesh.process_group, rank_id) + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) group = new_process_group(group_ranks) - # print("@@@@@@@@@@@@@@@@@@@@@ 5", group) intermediate_var_0 = dst_block.create_var( name=unique_name.generate_with_ignorable_key(".".join( @@ -558,13 +576,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=X_var.stop_gradient) + # copy X_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, + X_var) check_variable_and_dtype( X_var, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') - dst_block.append_op( + c_identity_op = 
dst_block.append_op( type='c_identity', inputs={'X': [X_var]}, outputs={'Out': intermediate_var_0}, @@ -581,12 +602,18 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ['float16', 'float32', 'float64'], 'linear') attrs = {'trans_x': False, 'trans_y': False} inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} - dst_block.append_op( + matmul_v2_op = dst_block.append_op( type='matmul_v2', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs) + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(c_identity_op, dst_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block, + op_dist_attr) + if in_dygraph_mode(): raise NotImplementedError( "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( @@ -675,15 +702,12 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): Out_var = dst_block.var(output_name_mapping['Out'][0]) # TODO infer logic comm presentation - from ..process import new_process_group - from ..transpiler import _get_comm_group model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( )._get_model_parallel_info() - group_ranks = _get_comm_group(process_mesh.topology, - model_parallel_axis, - process_mesh.process_group, rank_id) + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) group = new_process_group(group_ranks) - # print("@@@@@@@@@@@@@@@@@@@@@ 4", group) check_variable_and_dtype( X_var, 'x', ['float16', 'float32', 'float64'], 'linear') @@ -699,13 +723,17 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): persistable=False, is_data=False, need_check_feed=Out_var.desc.need_check_feed()) - dst_block.append_op( + # copy Out_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, + Out_var) + + matmul_v2_op = dst_block.append_op( type='matmul_v2', inputs=inputs, outputs={'Out': intermediate_var_0}, attrs=attrs) - dst_block.append_op( + c_allreduce_sum_op = dst_block.append_op( type='c_allreduce_sum', inputs={'X': intermediate_var_0}, outputs={'Out': Out_var}, @@ -715,6 +743,12 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): 'use_model_parallel': True }) + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block, + op_dist_attr) + if in_dygraph_mode(): raise NotImplementedError( "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 73a2c3c7a4..bd0c666968 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -582,6 +582,8 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS}) endif(NOT WIN32) endif(NOT APPLE) if(WITH_DGC) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py 
b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py index 18dcc36fe0..29ba863c96 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -135,6 +135,127 @@ def initialization_check(mode, dist_context, dist_startup_prog, return True +def get_input_var_dist_attr(op, main_program, dist_context): + varname = op.desc.input_arg_names() + var = main_program.global_block().var(varname[0]) + dist_attr = dist_context.get_tensor_distributed_attr_for_program(var) + return dist_attr + + +def get_output_var_dist_attr(op, main_program, dist_context): + varname = op.desc.output_arg_names() + var = main_program.global_block().var(varname[0]) + dist_attr = dist_context.get_tensor_distributed_attr_for_program(var) + return dist_attr + + +def check_equal_var_dist_attr(serial_dist_attr, dist_attr): + equal = True + if serial_dist_attr.get_process_mesh() != dist_attr.get_process_mesh() or \ + serial_dist_attr.is_parameter() != dist_attr.is_parameter() or \ + serial_dist_attr.get_dims_mapping() != dist_attr.get_dims_mapping(): + equal = False + return equal + + +def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops, + dist_op_idx): + equal = True + # get serial op's process_mesh and impl_idx + serial_op_dist_attr = dist_context.get_op_distributed_attr_for_program( + serial_op) + serial_process_mesh = serial_op_dist_attr.get_process_mesh() + serial_impl_idx = serial_op_dist_attr.get_impl_idx() + + # check dist_attr between serial op and dist op + for i in dist_op_idx: + op_dist_attr = dist_context.get_op_distributed_attr_for_program( + dist_ops[i]) + for in_varname in dist_ops[i].desc.input_arg_names(): + in_var = dist_main_prog.global_block().var(in_varname) + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + in_var) + tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + in_var_dims_mapping = op_dist_attr.get_input_dims_mapping( + in_varname) + if tensor_dims_mapping != in_var_dims_mapping: + equal = False + for out_varname in dist_ops[i].desc.output_arg_names(): + out_var = dist_main_prog.global_block().var(out_varname) + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + out_var) + tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + out_var_dims_mapping = op_dist_attr.get_output_dims_mapping( + out_varname) + if tensor_dims_mapping != out_var_dims_mapping: + equal = False + + dist_op_process_mesh = op_dist_attr.get_process_mesh() + dist_op_impl_idx = op_dist_attr.get_impl_idx() + if serial_op.desc.id() == dist_ops[i].desc.id() or \ + serial_process_mesh != dist_op_process_mesh or \ + serial_impl_idx != dist_op_impl_idx: + equal = False + + return equal + + +def distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog, + dist_context, serial_op_idx, + dist_op_idx): + + equal = True + serial_ops = serial_main_prog.global_block().ops + dist_ops = dist_main_prog.global_block().ops + + for i in range(len(serial_op_idx)): + serial_op = serial_ops[serial_op_idx[i]] + dist_op_0 = dist_ops[dist_op_idx[i][0]] + if dist_op_0.type == "c_identity": + # serial op input's dist_attr + serial_in_dist_attr = get_input_var_dist_attr( + serial_op, serial_main_prog, dist_context) + # c_identity output's(new var) dist_attr + identity_out_dist_attr = get_output_var_dist_attr( + dist_op_0, dist_main_prog, dist_context) + # check var dist_attr + equal = check_equal_var_dist_attr(serial_in_dist_attr, + 
identity_out_dist_attr) + else: + # serial op output's dist_attr + serial_out_dist_attr = get_output_var_dist_attr( + serial_op, serial_main_prog, dist_context) + # dist op output's(new var) dist_attr + out_dist_attr = get_output_var_dist_attr(dist_op_0, dist_main_prog, + dist_context) + # check var dist_attr + equal = check_equal_var_dist_attr(serial_out_dist_attr, + out_dist_attr) + + # check op's dist_attr + equal = check_equal_dist_op_attr(dist_context, dist_main_prog, + serial_op, dist_ops, dist_op_idx[i]) + + return equal + + +def distributed_attr_check_for_program(dist_main_prog, dist_context): + have_dist_attr = True + for block in dist_main_prog.blocks: + for tensor in block.vars.values(): + var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + tensor) + if var_dist_attr is None: + have_dist_attr = False + + for op in block.ops: + op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) + if op_dist_attr is None: + have_dist_attr = False + + return have_dist_attr + + class MLPLayer(nn.Layer): def __init__(self, hidden_size=1024, @@ -276,8 +397,8 @@ class TestMLPAutoPartitioner(unittest.TestCase): dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ - 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', - 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' + 'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu', + 'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout' ] self.assertTrue(dist_ops == ref_ops) @@ -289,6 +410,17 @@ class TestMLPAutoPartitioner(unittest.TestCase): dist_startup_prog, serial_startup_prog, var_need_broadcast)) + # check var and op all have dist_attr in dist_main_program + self.assertTrue( + distributed_attr_check_for_program(dist_main_prog, dist_context)) + # check distribured attr for dist op + serial_op_idx = [1, 4] + dist_op_idx = [[1, 2], [5, 6]] + self.assertTrue( + distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog, + dist_context, serial_op_idx, + dist_op_idx)) + def test_mlp_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" @@ -323,8 +455,8 @@ class TestMLPAutoPartitioner(unittest.TestCase): dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ - 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', - 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' + 'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu', + 'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout' ] self.assertTrue(dist_ops == ref_ops) @@ -336,6 +468,17 @@ class TestMLPAutoPartitioner(unittest.TestCase): dist_startup_prog, serial_startup_prog, var_need_broadcast)) + # check var and op all have dist_attr in dist_main_program + self.assertTrue( + distributed_attr_check_for_program(dist_main_prog, dist_context)) + # check distribured attr for dist op + serial_op_idx = [1, 4] + dist_op_idx = [[1, 2], [5, 6]] + self.assertTrue( + distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog, + dist_context, serial_op_idx, + dist_op_idx)) + class AttentionLayer(nn.Layer): def __init__(self, @@ -531,12 +674,12 @@ class TestAttentionAutoPartitioner(unittest.TestCase): dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ - 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', - 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', - 'elementwise_add', 'reshape2', 'transpose2', 
'reshape2', - 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', - 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', - 'elementwise_add' + 'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2', + 'transpose2', 'c_identity', 'matmul_v2', 'elementwise_add', + 'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2', + 'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax', + 'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2', + 'c_allreduce_sum', 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) @@ -547,6 +690,17 @@ class TestAttentionAutoPartitioner(unittest.TestCase): dist_startup_prog, serial_startup_prog, var_need_broadcast)) + # check var and op all have dist_attr in dist_main_program + self.assertTrue( + distributed_attr_check_for_program(dist_main_prog, dist_context)) + # check distribured attr for dist op + serial_op_idx = [0, 4, 6, 18] + dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]] + self.assertTrue( + distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog, + dist_context, serial_op_idx, + dist_op_idx)) + def test_attn_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" @@ -582,12 +736,12 @@ class TestAttentionAutoPartitioner(unittest.TestCase): dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ - 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', - 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', - 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', - 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', - 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', - 'elementwise_add' + 'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2', + 'transpose2', 'c_identity', 'matmul_v2', 'elementwise_add', + 'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2', + 'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax', + 'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2', + 'c_allreduce_sum', 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) @@ -598,6 +752,17 @@ class TestAttentionAutoPartitioner(unittest.TestCase): dist_startup_prog, serial_startup_prog, var_need_broadcast)) + # check var and op all have dist_attr in dist_main_program + self.assertTrue( + distributed_attr_check_for_program(dist_main_prog, dist_context)) + # check distribured attr for dist op + serial_op_idx = [0, 4, 6, 18] + dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]] + self.assertTrue( + distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog, + dist_context, serial_op_idx, + dist_op_idx)) + class DecoderLayer(nn.Layer): def __init__(self, @@ -859,15 +1024,16 @@ class TestDecoderLayerPartitioner(unittest.TestCase): dist_ops = [op.type for op in dist_ops] ref_ops = [ 'c_embedding', 'c_allreduce_sum', 'lookup_table_v2', - 'elementwise_add', 'dropout', 'layer_norm', 'c_identity', 'matmul', - 'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul', - 'elementwise_add', 'c_identity', 'matmul', 'elementwise_add', - 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', - 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', - 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout', - 'elementwise_add', 'layer_norm', 'c_identity', 'matmul', - 'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum', - 'elementwise_add', 'dropout', 'elementwise_add' + 'elementwise_add', 'dropout', 'layer_norm', 'c_identity', + 'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2', + 'c_identity', 
'matmul_v2', 'elementwise_add', 'c_identity', + 'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2', + 'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout', + 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2', + 'c_allreduce_sum', 'elementwise_add', 'dropout', 'elementwise_add', + 'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu', + 'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout', + 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) @@ -881,6 +1047,18 @@ class TestDecoderLayerPartitioner(unittest.TestCase): dist_startup_prog, serial_startup_prog, var_need_broadcast)) + # check var and op all have dist_attr in dist_main_program + self.assertTrue( + distributed_attr_check_for_program(dist_main_prog, dist_context)) + # check distribured attr + serial_op_idx = [0, 5, 9, 11, 23, 28, 31] + dist_op_idx = [[0, 1], [6, 7], [11, 12], [14, 15], [27, 28], [33, 34], + [37, 38]] + self.assertTrue( + distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog, + dist_context, serial_op_idx, + dist_op_idx)) + def test_decoder_noparallel(self): global _global_parallel_strategy _global_parallel_strategy = "None" @@ -923,13 +1101,13 @@ class TestDecoderLayerPartitioner(unittest.TestCase): dist_ops = [op.type for op in dist_ops] ref_ops = [ 'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout', - 'layer_norm', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', - 'matmul', 'elementwise_add', 'matmul', 'elementwise_add', - 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', - 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', - 'matmul', 'elementwise_add', 'dropout', 'elementwise_add', - 'layer_norm', 'matmul', 'elementwise_add', 'gelu', 'matmul', - 'elementwise_add', 'dropout', 'elementwise_add' + 'layer_norm', 'matmul_v2', 'elementwise_add', 'reshape2', + 'transpose2', 'matmul_v2', 'elementwise_add', 'matmul_v2', + 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', + 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', + 'transpose2', 'reshape2', 'matmul_v2', 'elementwise_add', 'dropout', + 'elementwise_add', 'layer_norm', 'matmul_v2', 'elementwise_add', + 'gelu', 'matmul_v2', 'elementwise_add', 'dropout', 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) dist_ops = dist_startup_prog.global_block().ops -- GitLab
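
A note on the new helpers in python/paddle/distributed/auto_parallel/operators/common.py: when the partitioner lowers a serial op into distributed ops (matmul_v2 into c_identity + matmul_v2, an embedding lookup into c_embedding + c_allreduce_sum), the intermediate variable and the freshly appended ops previously carried no distributed attributes. copy_distributed_attr_for_var gives the new var the dist_attr of the serial var it stands in for, and copy_distributed_attr_for_dist_op gives each new op the serial op's process mesh and impl_idx plus dims_mappings read back from its actual inputs and outputs. The sketch below only illustrates that copy pattern with simplified stand-in classes; TensorDistAttr, OpDistAttr, the ctx dict and the "X@tmp" name are hypothetical placeholders, not Paddle's TensorDistributedAttribute/OperatorDistributedAttribute API.

import copy
from dataclasses import dataclass, field

# Hypothetical, simplified stand-ins for Paddle's tensor/op dist_attr objects.
@dataclass
class TensorDistAttr:
    process_mesh: list
    dims_mapping: list

@dataclass
class OpDistAttr:
    process_mesh: list
    impl_idx: int
    input_dims_mapping: dict = field(default_factory=dict)
    output_dims_mapping: dict = field(default_factory=dict)

def copy_var_dist_attr(ctx, src_var, new_var):
    # The new intermediate var inherits the sharding of the var it replaces.
    ctx[new_var] = copy.deepcopy(ctx[src_var])

def copy_op_dist_attr(ctx, serial_op_attr, input_names, output_names):
    # The appended dist op reuses the serial op's mesh and impl choice, and
    # picks up dims_mapping from the tensors it actually reads and writes.
    attr = OpDistAttr(serial_op_attr.process_mesh, serial_op_attr.impl_idx)
    for name in input_names:
        attr.input_dims_mapping[name] = ctx[name].dims_mapping
    for name in output_names:
        attr.output_dims_mapping[name] = ctx[name].dims_mapping
    return attr

# Toy usage: a column-parallel matmul whose weight is sharded on its last dim.
ctx = {
    "X": TensorDistAttr(process_mesh=[0, 1], dims_mapping=[-1, -1]),
    "W": TensorDistAttr(process_mesh=[0, 1], dims_mapping=[-1, 0]),
    "Out": TensorDistAttr(process_mesh=[0, 1], dims_mapping=[-1, 0]),
}
copy_var_dist_attr(ctx, "X", "X@tmp")  # c_identity output mirrors X
serial_attr = OpDistAttr(process_mesh=[0, 1], impl_idx=0)
matmul_attr = copy_op_dist_attr(ctx, serial_attr, ["X@tmp", "W"], ["Out"])
print(matmul_attr.input_dims_mapping, matmul_attr.output_dims_mapping)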
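
A note on the reordered _get_comm_group(...) calls in dist_matmul.py: the helper now receives the flattened process group first, then the mesh topology, then the model-parallel axis, then rank_id. Conceptually the call answers "which ranks lie on the same line of the process mesh as rank_id along the model-parallel axis", and those ranks form the communication group used by c_identity / c_allreduce_sum. The function below is a standalone illustration of that meaning under the stated assumption about the semantics; comm_group_along_axis is an invented name, not Paddle's _get_comm_group implementation.

import numpy as np

def comm_group_along_axis(process_group, topology, axis, rank_id):
    """Ranks sharing all mesh coordinates with rank_id except along `axis`.

    Illustrative only; the argument order mirrors the corrected call
    _get_comm_group(process_mesh.process_group, process_mesh.topology,
                    model_parallel_axis, rank_id).
    """
    mesh = np.array(process_group).reshape(topology)
    coord = np.argwhere(mesh == rank_id)[0]  # rank_id's coordinates in the mesh
    index = [slice(None) if d == axis else c for d, c in enumerate(coord)]
    return mesh[tuple(index)].tolist()       # the line of ranks along `axis`

# A 2x4 mesh with dp on axis 0 and mp on axis 1: rank 5's model-parallel group.
print(comm_group_along_axis(list(range(8)), [2, 4], axis=1, rank_id=5))
# -> [4, 5, 6, 7]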