From 6f3c96438b4ac3a199195d300b38563221c661b6 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 14 Apr 2023 14:01:31 +0800 Subject: [PATCH] Eb118 BF16 Adoption (#52827) * pr1 * pr2 * pr3 * fixed unitest * adopt for scale --- .../distributed/auto_parallel/constants.py | 6 +- .../auto_parallel/operators/dist_embedding.py | 11 +- .../auto_parallel/operators/dist_matmul.py | 1680 ++++++++++------- .../auto_parallel/parallelizer_v2.py | 19 +- .../distributed/passes/auto_parallel_amp.py | 513 +++-- .../distributed/passes/auto_parallel_fp16.py | 408 ++-- .../unittests/auto_parallel/CMakeLists.txt | 3 + .../unittests/auto_parallel/amp_o2_pass.py | 142 ++ .../auto_parallel/amp_pass_unittest.py | 2 +- .../auto_parallel/test_amp_o2_pass.py | 55 + .../unittests/auto_parallel/test_strategy.py | 9 +- 11 files changed, 1878 insertions(+), 970 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 44d804a4816..19d444248fa 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False) ######################################### AMP = "amp" set_field_default_config(AMP, "enable", False) +set_field_default_config(AMP, "dtype", "float16") +set_field_default_config(AMP, "level", "o1") set_field_default_config(AMP, "init_loss_scaling", 32768.0) set_field_default_config(AMP, "incr_every_n_steps", 1000) set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) @@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True) set_field_default_config(AMP, "custom_white_list", []) set_field_default_config(AMP, "custom_black_list", []) set_field_default_config(AMP, "custom_black_varnames", []) -set_field_default_config(AMP, "use_pure_fp16", False) -set_field_default_config(AMP, "use_fp16_guard", True) +set_field_default_config(AMP, "use_fp16_guard", False) +set_field_default_config(AMP, "use_bf16_guard", False) set_field_default_config(AMP, "use_optimizer_fp16", False) ######################################### diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 26bed30871c..5e394a12c7f 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): check_variable_and_dtype( Out_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'c_allreduce_sum', ) @@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): check_variable_and_dtype( Out_grad, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_identity', ) @@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): }, ) check_variable_and_dtype( - intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', ) check_dtype( intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], + 
['float16', 'float32', 'float64', 'uint16'], 'linear', ) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 8f2db1a3b26..78da90b812c 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -20,7 +20,11 @@ from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from .common import gradient_synchronization -from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related +from .common import ( + set_comm_op_dist_attr_for_program, + naive_copy_op_dist_attr_for_program, + is_parameter_related, +) from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -33,24 +37,39 @@ from paddle.fluid import core, unique_name from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype -from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from paddle.distributed.fleet.meta_optimizers.common import ( + OpRole, + OP_ROLE_KEY, + OP_ROLE_VAR_KEY, +) from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank from .dist_default import DistributedDefaultImpl0 -from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs +from ..cost import ( + build_comp_desc_from_dist_op, + build_comm_desc_from_dist_op, + build_dp_costs, +) from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost -from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost +from paddle.distributed.auto_parallel.cost.comm_op_cost import ( + AllreduceSumOpCost, + IdentityOpCost, +) def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping): if trans_x: - x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[ - -2], x_dims_mapping[-1] + x_dims_mapping[-1], x_dims_mapping[-2] = ( + x_dims_mapping[-2], + x_dims_mapping[-1], + ) if trans_y: - y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[ - -2], y_dims_mapping[-1] + y_dims_mapping[-1], y_dims_mapping[-2] = ( + y_dims_mapping[-2], + y_dims_mapping[-1], + ) def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): @@ -123,13 +142,17 @@ def _update_dims_mapping_for_matmul(dist_op): for i in range(new_out_dims_mapping_len - 2): broadcast_out_dims_mapping.append(out_dims_mapping[i]) - compatible_dims_mapping = compute_compatible_dims_mapping([ - broadcast_x_dims_mapping, broadcast_y_dims_mapping, - broadcast_out_dims_mapping - ]) + compatible_dims_mapping = compute_compatible_dims_mapping( + [ + broadcast_x_dims_mapping, + broadcast_y_dims_mapping, + broadcast_out_dims_mapping, + ] + ) if compatible_dims_mapping is None: - trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, - y_dims_mapping) + trans_x_y_dims_mapping( + trans_x, trans_y, x_dims_mapping, y_dims_mapping + ) return False for i in range(new_x_dims_mapping_len - 2): @@ -152,17 +175,20 @@ def _update_dims_mapping_for_matmul(dist_op): # The following which 
uses negative index can be work # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 dim_changed = compute_compatible_and_update_dim_mapping( - [x_dims_mapping, y_dims_mapping], [-1, -2]) + [x_dims_mapping, y_dims_mapping], [-1, -2] + ) if dim_changed: changed = True dim_changed = compute_compatible_and_update_dim_mapping( - [x_dims_mapping, out_dims_mapping], [-2, -2]) + [x_dims_mapping, out_dims_mapping], [-2, -2] + ) if dim_changed: changed = True dim_changed = compute_compatible_and_update_dim_mapping( - [y_dims_mapping, out_dims_mapping], [-1, -1]) + [y_dims_mapping, out_dims_mapping], [-1, -1] + ) if dim_changed: changed = True @@ -202,7 +228,8 @@ def _is_auto_compatible_for_matmul(dist_op): x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name)) y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name)) out_dims_mapping = copy.deepcopy( - op_dist_attr.get_output_dims_mapping(out_name)) + op_dist_attr.get_output_dims_mapping(out_name) + ) x_dims_mapping_len = len(x_dims_mapping) y_dims_mapping_len = len(y_dims_mapping) out_dims_mapping_len = len(out_dims_mapping) @@ -234,22 +261,23 @@ def _is_auto_compatible_for_matmul(dist_op): for i in range(out_dims_mapping_len - 2): broadcast_out_dims_mapping.append(out_dims_mapping[i]) - is_same = ((broadcast_x_dims_mapping == broadcast_y_dims_mapping) - and (broadcast_x_dims_mapping == broadcast_out_dims_mapping)) + is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and ( + broadcast_x_dims_mapping == broadcast_out_dims_mapping + ) if not is_same: return False # The following which uses negative index can be work # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 - is_same = (x_dims_mapping[-1] == y_dims_mapping[-2]) + is_same = x_dims_mapping[-1] == y_dims_mapping[-2] if not is_same: return False - is_same = (x_dims_mapping[-2] == out_dims_mapping[-2]) + is_same = x_dims_mapping[-2] == out_dims_mapping[-2] if not is_same: return False - is_same = (y_dims_mapping[-1] == out_dims_mapping[-1]) + is_same = y_dims_mapping[-1] == out_dims_mapping[-1] if not is_same: return False @@ -265,8 +293,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): backward_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(backward_op)) + assert ( + dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format(str(backward_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in dist_attr.process_mesh.processes: @@ -277,22 +306,26 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD') assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD') assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD') - assert len( + assert ( + len(kwargs['Y']) == 1 + ), "row_parallel_embedding input Ids take 1 variable but got {}".format( kwargs['Y'] - ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['Y']) - assert len( + ) + assert ( + len(kwargs['X']) == 1 + ), "row_parallel_embedding input Ids take 1 variable but got {}".format( kwargs['X'] - ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( - kwargs['X']) - assert len( - kwargs['Out@GRAD'] - ) == 1, "row_parallel_embedding 
input Ids take 1 variable but got {}".format( - kwargs['Out']) - assert len( + ) + assert ( + len(kwargs['Out@GRAD']) == 1 + ), "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Out'] + ) + assert ( + len(kwargs['Y@GRAD']) == 1 + ), "row_parallel_embedding output Ids take 1 variable but got {}".format( kwargs['Y@GRAD'] - ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( - kwargs['Y@GRAD']) + ) X_var = main_block.var(kwargs['X'][0]) Y_var = main_block._var_recursive(kwargs['Y'][0]) @@ -302,7 +335,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): assert not is_parameter_related( X_var.name, main_block ), "left operand(X) [{}] of dist matmul should not be parameter".format( - X_var.name) + X_var.name + ) X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name) Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) @@ -339,28 +373,34 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): parallel_axis = Y_var_dim_mapping[0] check_variable_and_dtype( - Out_grad, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], - '_c_identity') + Out_grad, + 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + '_c_identity', + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_identity", 'tmp'])) + "@GRAD", + name=unique_name.generate_with_ignorable_key( + ".".join(["c_identity", 'tmp']) + ) + + "@GRAD", dtype=Out_grad.dtype, shape=Out_grad.shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=Out_grad.stop_gradient) + stop_gradient=Out_grad.stop_gradient, + ) # copy X_var's dist_attr to intermediate_var_0's dist_attr out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name) assert out_grad_dist_attr is not None - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - out_grad_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, out_grad_dist_attr + ) - group_ranks = _get_comm_group(process_mesh_group, - process_mesh_shape, parallel_axis, - rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) c_identity_op = main_block.append_op( type='c_identity', @@ -371,20 +411,29 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): 'use_calc_stream': True, 'use_model_parallel': True, OP_ROLE_KEY: OpRole.Backward, - }) - check_variable_and_dtype(intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], - 'linear') - check_dtype(intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') - set_comm_op_dist_attr_for_program(c_identity_op, - dist_attr.process_mesh, - out_grad_dist_attr, ctx) + }, + ) + check_variable_and_dtype( + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) + check_dtype( + intermediate_var_0.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) + set_comm_op_dist_attr_for_program( + c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx + ) new_kwargs = copy.deepcopy(kwargs) new_kwargs['Out@GRAD'] = [intermediate_var_0.name] matmul_op_desc = copy_op_with_new_input_output( - ctx, main_block, backward_op, **new_kwargs) + ctx, main_block, backward_op, **new_kwargs + ) else: # col parallel: matmul + allreduce assert Y_var_dim_mapping[0] < 0 @@ -397,28 +446,36 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): 
assert len(kwargs['X@GRAD']) == 1 X_grad = main_block.var(kwargs['X@GRAD'][0]) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_identity", 'tmp'])) + "@GRAD", + name=unique_name.generate_with_ignorable_key( + ".".join(["c_identity", 'tmp']) + ) + + "@GRAD", dtype=X_grad.dtype, shape=X_grad.shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=X_grad.stop_gradient) + stop_gradient=X_grad.stop_gradient, + ) X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name) assert X_grad_dist_attr is not None - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - X_grad_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, X_grad_dist_attr + ) new_kwargs['X@GRAD'] = [intermediate_var_0.name] matmul_op_desc = copy_op_with_new_input_output( - ctx, main_block, backward_op, **new_kwargs) + ctx, main_block, backward_op, **new_kwargs + ) # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad if has_x_grad: - group_ranks = _get_comm_group(process_mesh_group, - process_mesh_shape, parallel_axis, - rank_id) + group_ranks = _get_comm_group( + process_mesh_group, + process_mesh_shape, + parallel_axis, + rank_id, + ) group = new_process_group(group_ranks) c_allreduce_sum_op = main_block.append_op( type='c_allreduce_sum', @@ -428,15 +485,20 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, - OP_ROLE_KEY: OpRole.Backward - }) - set_comm_op_dist_attr_for_program(c_allreduce_sum_op, - dist_attr.process_mesh, - X_grad_dist_attr, ctx) + OP_ROLE_KEY: OpRole.Backward, + }, + ) + set_comm_op_dist_attr_for_program( + c_allreduce_sum_op, + dist_attr.process_mesh, + X_grad_dist_attr, + ctx, + ) else: # replicate - matmul_op_desc = copy_op_with_new_input_output(ctx, main_block, - backward_op, **kwargs) + matmul_op_desc = copy_op_with_new_input_output( + ctx, main_block, backward_op, **kwargs + ) # data parallel gradient synchronization act_grad_names = [X_var.name] @@ -448,8 +510,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): if trans_x: trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None) - gradient_synchronization(ctx, backward_op, act_grad_names, out_grad_names, - rank_id) + gradient_synchronization( + ctx, backward_op, act_grad_names, out_grad_names, rank_id + ) if trans_x: trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None) @@ -472,23 +535,25 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): if size <= 1 or axis in dim_mapping: pass else: - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, axis, rank_id) + group_ranks = _get_comm_group( + process_mesh.processes, process_mesh.topology, axis, rank_id + ) sync_group = new_process_group(group_ranks) - startup_block.append_op(type='c_broadcast', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': sync_group.id, - 'root': 0, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Forward - }) + startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward, + }, + ) class DistributedMatmul(DistributedOperatorImplContainer): - def __init__(self, op_type): super(DistributedMatmul, self).__init__(op_type) @@ -498,7 +563,6 @@ register_distributed_operator_impl_container(DistributedMatmul("matmul")) # 
ColumnParallel class DistributedMatmulImpl0(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMatmulImpl0, self).__init__(name) self._forward_implemented = True @@ -521,7 +585,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): main_block = backward_op.block vars = main_block.vars Y_var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("Y")[0]) + backward_op.input("Y")[0] + ) # col parallel: matmul + allreduce assert Y_var_dim_mapping[0] < 0 parallel_axis = Y_var_dim_mapping[1] @@ -531,13 +596,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): assert len(backward_op.output("X@GRAD")) == 1 # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) process_mesh = dist_attr.process_mesh processes = process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulGradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # calc comm op cost @@ -550,40 +616,52 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - AllreduceSumOpCost, ctx, processes, - c_allreduce_sum_desc_mapping, cluster) + AllreduceSumOpCost, + ctx, + processes, + c_allreduce_sum_desc_mapping, + cluster, + ) res.append(comm_op_cost_list) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, - desc_mapping, cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulOpCost, ctx, processes, desc_mapping, cluster + ) # calc comm op cost serial_op = dist_op.serial_op vars = serial_op.block.vars parallel_axis = dist_op.dist_attr.get_input_dims_mapping( - serial_op.input("Y")[0])[-1] + serial_op.input("Y")[0] + )[-1] attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = serial_op.input("X") c_identity_desc_mapping = build_comm_desc_from_dist_op( @@ -592,10 +670,12 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + IdentityOpCost, ctx, processes, 
c_identity_desc_mapping, cluster + ) res_cost = [comm_op_cost_list, cost_mapping] return res_cost @@ -606,16 +686,19 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(x_name)) + op_dist_attr.get_input_dims_mapping(x_name) + ) y_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(y_name)) + op_dist_attr.get_input_dims_mapping(y_name) + ) trans_x = op_desc.attr('transpose_X') trans_y = op_desc.attr('transpose_Y') trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_shard(x_dims_mapping[-1]): return False if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( - y_dims_mapping[-1]): + y_dims_mapping[-1] + ): return False for mapping in x_dims_mapping[1:-1]: if is_dim_shard(mapping): @@ -635,8 +718,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): return False @@ -661,28 +745,33 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(src_op)) + assert ( + op_dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, op_dist_attr.process_mesh, rank_id + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) assert len(kwargs[input_name]) == len( src_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in src_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( src_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block.var(kwargs['Y'][0]) @@ -692,18 +781,24 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): # TODO infer logic comm presentation matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-1] + Weight_var.name + )[-1] if trans_y: matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-2] - assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( - matmul_col_dim_mapping) + Weight_var.name + )[-2] + assert ( + matmul_col_dim_mapping >= 0 + ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_col_dim_mapping + ) process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_col_dim_mapping - group_ranks = 
_get_comm_group(process_mesh_group, process_mesh_shape, - parallel_axis, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) # infer new var shape with op dist attr @@ -711,31 +806,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): assert x_tensor_dist_attr is not None identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) assert identity_var_dist_attr is not None - ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, - identity_var_dist_attr) + ref_shape_x = infer_shape( + main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr + ) # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, - out_var_dist_attr) + ref_shape_out = infer_shape( + main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_identity", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_identity", 'tmp']) + ), dtype=X_var.dtype, shape=X_var.shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=X_var.stop_gradient) + stop_gradient=X_var.stop_gradient, + ) # set intermediate_var_0's dist_attr with X_var's dist_attr - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - identity_var_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, identity_var_dist_attr + ) check_variable_and_dtype( - X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') + X_var, + 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + '_c_identity', + ) c_identity_op = main_block.append_op( type='c_identity', @@ -745,26 +848,34 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, - OP_ROLE_KEY: src_op.attr('op_role') - }) + OP_ROLE_KEY: src_op.attr('op_role'), + }, + ) if intermediate_var_0.shape != ref_shape_x: intermediate_var_0.desc.set_shape(ref_shape_x) - check_variable_and_dtype(intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], 'linear') - check_dtype(intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') + check_variable_and_dtype( + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) + check_dtype( + intermediate_var_0.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) attrs = { 'transpose_X': trans_x, 'transpose_Y': trans_y, 'alpha': 1, - OP_ROLE_KEY: src_op.attr('op_role') + OP_ROLE_KEY: src_op.attr('op_role'), } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} - matmul_op = main_block.append_op(type='matmul', - inputs=inputs, - outputs={'Out': Out_var}, - attrs=attrs) + matmul_op = main_block.append_op( + type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs + ) if Out_var.shape != ref_shape_out: Out_var.desc.set_shape(ref_shape_out) @@ -778,13 +889,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): input_varname = c_identity_op.desc.input_arg_names()[0] input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) assert input_dist_attr is not None, "dist_attr is {}".format( - 
op_dist_attr) - identity_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + identity_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) # output output_varname = c_identity_op.desc.output_arg_names()[0] - identity_op_dist_attr.set_output_dist_attr(output_varname, - input_dist_attr) + identity_op_dist_attr.set_output_dist_attr( + output_varname, input_dist_attr + ) # set op dist attr ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) @@ -797,31 +911,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): for input_varname in matmul_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr( - input_varname) + input_varname + ) assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmul_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + matmul_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) else: input_var = main_block.var(input_varname) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( - input_var) - matmul_op_dist_attr.set_input_dist_attr(input_varname, - tensor_dist_attr) + input_var + ) + matmul_op_dist_attr.set_input_dist_attr( + input_varname, tensor_dist_attr + ) # output output_varname = matmul_op.desc.output_arg_names()[0] output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmul_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) + op_dist_attr + ) + matmul_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) # set op dist attr ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) # init param sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: - _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, - rank_id) + _init_param_sync( + Weight_var, dist_op_context, startup_block, ctx, rank_id + ) @staticmethod def backward(ctx, *args, **kwargs): @@ -830,7 +952,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): # RowParallel class DistributedMatmulImpl1(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMatmulImpl1, self).__init__(name) self._forward_implemented = True @@ -853,7 +974,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): main_block = backward_op.block vars = main_block.vars Y_var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("Y")[0]) + backward_op.input("Y")[0] + ) assert Y_var_dim_mapping[1] < 0 parallel_axis = Y_var_dim_mapping[0] @@ -866,50 +988,60 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) process_mesh = dist_attr.process_mesh processes = process_mesh.processes comm_op_cost_list = build_comm_costs_from_descs( - IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster + ) res.append(comm_op_cost_list) # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) - cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, - processes, desc_mapping, - cluster) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) + cost_mapping = build_comp_costs_from_descs( + MatmulGradOpCost, ctx, processes, desc_mapping, cluster + ) 
res.append(cost_mapping) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, - desc_mapping, cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulOpCost, ctx, processes, desc_mapping, cluster + ) # calc comm op cost serial_op = dist_op.serial_op vars = serial_op.block.vars parallel_axis = dist_op.dist_attr.get_input_dims_mapping( - serial_op.input("Y")[0])[-2] + serial_op.input("Y")[0] + )[-2] attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = serial_op.output("Out") @@ -919,11 +1051,16 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, - cluster) + AllreduceSumOpCost, + ctx, + processes, + c_allreduce_sum_desc_mapping, + cluster, + ) res_cost = [cost_mapping, comm_op_cost_list] @@ -935,16 +1072,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(x_name)) + op_dist_attr.get_input_dims_mapping(x_name) + ) y_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(y_name)) + op_dist_attr.get_input_dims_mapping(y_name) + ) trans_x = op_desc.attr('transpose_X') trans_y = op_desc.attr('transpose_Y') trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_replicate(x_dims_mapping[-1]): return False if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( - y_dims_mapping[-1]): + y_dims_mapping[-1] + ): return False # Other dimensions must be replicate except the batch dimension for mapping in x_dims_mapping[1:-1]: @@ -966,8 +1106,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): return False @@ -992,28 +1133,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(src_op)) + assert 
( + op_dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, op_dist_attr.process_mesh, rank_id + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) assert len(kwargs[input_name]) == len( src_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in src_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( src_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block.var(kwargs['Y'][0]) @@ -1023,29 +1169,40 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): # TODO infer logic comm presentation matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-2] + Weight_var.name + )[-2] if trans_y: matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-1] - assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( - matmul_row_dim_mapping) + Weight_var.name + )[-1] + assert ( + matmul_row_dim_mapping >= 0 + ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_row_dim_mapping + ) process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_row_dim_mapping - group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, - parallel_axis, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) - check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], - 'linear') - check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], - 'linear') + check_variable_and_dtype( + X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' + ) + check_dtype( + X_var.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) attrs = { 'transpose_X': trans_x, 'transpose_Y': trans_y, 'alpha': 1, - OP_ROLE_KEY: src_op.attr('op_role') + OP_ROLE_KEY: src_op.attr('op_role'), } inputs = {'X': X_var, 'Y': Weight_var} @@ -1054,27 +1211,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, - out_var_dist_attr) + ref_shape = infer_shape( + main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_allreduce_sum", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_allreduce_sum", 'tmp']) + ), shape=Out_var.shape, dtype=Out_var.dtype, type=Out_var.type, lod_level=Out_var.lod_level, persistable=False, is_data=False, - need_check_feed=Out_var.desc.need_check_feed()) + 
need_check_feed=Out_var.desc.need_check_feed(), + ) # set intermediate_var_0's dist_attr with Out_var's dist_attr - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - out_var_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, out_var_dist_attr + ) - matmul_op = main_block.append_op(type='matmul', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs) + matmul_op = main_block.append_op( + type='matmul', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs, + ) if intermediate_var_0.shape != ref_shape: intermediate_var_0.desc.set_shape(ref_shape) @@ -1086,8 +1249,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, - OP_ROLE_KEY: src_op.attr('op_role') - }) + OP_ROLE_KEY: src_op.attr('op_role'), + }, + ) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -1100,15 +1264,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): for input_varname in matmul_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmul_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + matmul_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) output_varname = matmul_op.desc.output_arg_names()[0] output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmul_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) + op_dist_attr + ) + matmul_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) # allreduce @@ -1120,21 +1288,26 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): input_var = main_block.var(input_varname) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) assert tensor_dist_attr is not None - allreduce_op_dist_attr.set_input_dist_attr(input_varname, - tensor_dist_attr) + allreduce_op_dist_attr.set_input_dist_attr( + input_varname, tensor_dist_attr + ) for output_varname in c_allreduce_sum_op.desc.output_arg_names(): output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - allreduce_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) - ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, - allreduce_op_dist_attr) + op_dist_attr + ) + allreduce_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) + ctx.set_op_dist_attr_for_program( + c_allreduce_sum_op, allreduce_op_dist_attr + ) # init param sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: - _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, - rank_id) + _init_param_sync( + Weight_var, dist_op_context, startup_block, ctx, rank_id + ) @staticmethod def backward(ctx, *args, **kwargs): @@ -1143,7 +1316,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): # ReplicateParallel class DistributedMatmulImpl2(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMatmulImpl2, self).__init__(name) @@ -1164,38 +1336,45 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): vars = main_block.vars # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = 
build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) process_mesh = dist_attr.process_mesh processes = process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulGradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, - desc_mapping, cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulOpCost, ctx, processes, desc_mapping, cluster + ) res_cost = [cost_mapping] return res_cost @@ -1211,13 +1390,15 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): if is_dim_shard(x_dims_mapping[-1]): return False if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard( - x_dims_mapping[-2]): + x_dims_mapping[-2] + ): return False if is_dim_shard(y_dims_mapping[-1]): return False if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard( - y_dims_mapping[-2]): + y_dims_mapping[-2] + ): return False return True @@ -1231,14 +1412,16 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): if is_dim_shard(out_dims_mapping[-1]): return False if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard( - out_dims_mapping[-2]): + out_dims_mapping[-2] + ): return False return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): @@ -1262,16 +1445,18 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) -register_distributed_operator_impl("matmul", - DistributedMatmulImpl0("column_parallel")) -register_distributed_operator_impl("matmul", - DistributedMatmulImpl1("row_parallel")) -register_distributed_operator_impl("matmul", - DistributedMatmulImpl2("replicate_parallel")) +register_distributed_operator_impl( + "matmul", DistributedMatmulImpl0("column_parallel") +) +register_distributed_operator_impl( + "matmul", DistributedMatmulImpl1("row_parallel") +) +register_distributed_operator_impl( + "matmul", DistributedMatmulImpl2("replicate_parallel") +) class DistributedMatmulV2(DistributedOperatorImplContainer): - def __init__(self, op_type): super(DistributedMatmulV2, self).__init__(op_type) @@ -1281,7 +1466,6 @@ 
register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2")) # ColumnParallel class DistributedMatmulV2Impl0(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMatmulV2Impl0, self).__init__(name) self._forward_implemented = True @@ -1304,7 +1488,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): main_block = backward_op.block vars = main_block.vars Y_var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("Y")[0]) + backward_op.input("Y")[0] + ) process_mesh = dist_attr.process_mesh processes = process_mesh.processes # col parallel: matmul + allreduce @@ -1318,12 +1503,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): assert len(backward_op.output("X@GRAD")) == 1 # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) - cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # calc comm op cost @@ -1336,45 +1522,55 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - AllreduceSumOpCost, ctx, processes, - c_allreduce_sum_desc_mapping, cluster) + AllreduceSumOpCost, + ctx, + processes, + c_allreduce_sum_desc_mapping, + cluster, + ) res.append(comm_op_cost_list) # need gradient allreduce process_mesh = dist_attr.process_mesh var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost # TODO: trans shape if trans_x or trans_y is True - comp_desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + comp_desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - comp_cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, - processes, - comp_desc_mapping, - cluster) + comp_cost_mapping = build_comp_costs_from_descs( + MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster + ) # calc comm op cost serial_op = dist_op.serial_op vars = serial_op.block.vars parallel_axis = dist_op.dist_attr.get_input_dims_mapping( - serial_op.input("Y")[0])[-1] + serial_op.input("Y")[0] + )[-1] attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = serial_op.input("X") @@ -1384,9 +1580,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - IdentityOpCost, 
ctx, processes, c_identity_desc_mapping, cluster) + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster + ) res_cost = [comm_op_cost_list, comp_cost_mapping] return res_cost @@ -1397,16 +1595,19 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(x_name)) + op_dist_attr.get_input_dims_mapping(x_name) + ) y_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(y_name)) + op_dist_attr.get_input_dims_mapping(y_name) + ) trans_x = op_desc.attr('trans_x') trans_y = op_desc.attr('trans_y') trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_shard(x_dims_mapping[-1]): return False if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( - y_dims_mapping[-1]): + y_dims_mapping[-1] + ): return False for mapping in x_dims_mapping[1:-1]: if is_dim_shard(mapping): @@ -1426,8 +1627,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): return False @@ -1452,28 +1654,33 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(src_op)) + assert ( + op_dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, op_dist_attr.process_mesh, rank_id + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) assert len(kwargs[input_name]) == len( src_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in src_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( src_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0]) @@ -1483,18 +1690,24 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): # TODO infer logic comm presentation matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-1] + Weight_var.name + )[-1] if trans_y: matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-2] - assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( - matmul_col_dim_mapping) + Weight_var.name + )[-2] + assert ( + matmul_col_dim_mapping >= 0 + ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_col_dim_mapping + ) process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = 
op_dist_attr.process_mesh.processes parallel_axis = matmul_col_dim_mapping - group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, - parallel_axis, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) # infer new var shape with op dist attr @@ -1502,31 +1715,39 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): assert x_tensor_dist_attr is not None identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) assert identity_var_dist_attr is not None - ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, - identity_var_dist_attr) + ref_shape_x = infer_shape( + main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr + ) # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, - out_var_dist_attr) + ref_shape_out = infer_shape( + main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_identity", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_identity", 'tmp']) + ), dtype=X_var.dtype, shape=X_var.shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=X_var.stop_gradient) + stop_gradient=X_var.stop_gradient, + ) # set intermediate_var_0's dist_attr with X_var's dist_attr - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - identity_var_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, identity_var_dist_attr + ) check_variable_and_dtype( - X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') + X_var, + 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + '_c_identity', + ) c_identity_op = main_block.append_op( type='c_identity', inputs={'X': [X_var]}, @@ -1536,24 +1757,35 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): 'use_calc_stream': True, 'use_model_parallel': True, OP_ROLE_KEY: src_op.attr('op_role'), - }) + }, + ) if intermediate_var_0.shape != ref_shape_x: intermediate_var_0.desc.set_shape(ref_shape_x) - check_variable_and_dtype(intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], 'linear') - check_dtype(intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') + check_variable_and_dtype( + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) + check_dtype( + intermediate_var_0.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) attrs = { 'trans_x': trans_x, 'trans_y': trans_y, - OP_ROLE_KEY: src_op.attr('op_role') + OP_ROLE_KEY: src_op.attr('op_role'), } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} - matmul_v2_op = main_block.append_op(type='matmul_v2', - inputs=inputs, - outputs={'Out': Out_var}, - attrs=attrs) + matmul_v2_op = main_block.append_op( + type='matmul_v2', + inputs=inputs, + outputs={'Out': Out_var}, + attrs=attrs, + ) if Out_var.shape != ref_shape_out: Out_var.desc.set_shape(ref_shape_out) @@ -1567,13 +1799,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): input_varname = c_identity_op.desc.input_arg_names()[0] input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) 
assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - identity_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + identity_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) # output output_varname = c_identity_op.desc.output_arg_names()[0] - identity_op_dist_attr.set_output_dist_attr(output_varname, - input_dist_attr) + identity_op_dist_attr.set_output_dist_attr( + output_varname, input_dist_attr + ) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) # matmulv2 @@ -1584,29 +1819,37 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): for input_varname in matmul_v2_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr( - input_varname) + input_varname + ) assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) + op_dist_attr + ) matmulv2_op_dist_attr.set_input_dist_attr( - input_varname, input_dist_attr) + input_varname, input_dist_attr + ) else: input_var = main_block.var(input_varname) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( - input_var) + input_var + ) matmulv2_op_dist_attr.set_input_dist_attr( - input_varname, tensor_dist_attr) + input_varname, tensor_dist_attr + ) for output_varname in matmul_v2_op.desc.output_arg_names(): output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmulv2_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) + op_dist_attr + ) + matmulv2_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) # init param sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: - _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, - rank_id) + _init_param_sync( + Weight_var, dist_op_context, startup_block, ctx, rank_id + ) @staticmethod def backward(ctx, *args, **kwargs): @@ -1615,7 +1858,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): # RowParallel class DistributedMatmulV2Impl1(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMatmulV2Impl1, self).__init__(name) self._forward_implemented = True @@ -1638,7 +1880,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): main_block = backward_op.block vars = main_block.vars Y_var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("Y")[0]) + backward_op.input("Y")[0] + ) assert Y_var_dim_mapping[1] < 0 parallel_axis = Y_var_dim_mapping[0] @@ -1653,50 +1896,59 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster + ) res.append(comm_op_cost_list) # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) - cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, - processes, desc_mapping, - cluster) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) + cost_mapping = build_comp_costs_from_descs( + MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # need gradient allreduce process_mesh = dist_attr.process_mesh var_dim_mapping = 
dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulV2OpCost, ctx, processes, desc_mapping, cluster + ) # calc comm op cost serial_op = dist_op.serial_op vars = serial_op.block.vars parallel_axis = dist_op.dist_attr.get_input_dims_mapping( - serial_op.input("Y")[0])[-2] + serial_op.input("Y")[0] + )[-2] attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = serial_op.output("Out") @@ -1706,11 +1958,16 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, - cluster) + AllreduceSumOpCost, + ctx, + processes, + c_allreduce_sum_desc_mapping, + cluster, + ) res_cost = [cost_mapping, comm_op_cost_list] return res_cost @@ -1721,16 +1978,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(x_name)) + op_dist_attr.get_input_dims_mapping(x_name) + ) y_dims_mapping = copy.deepcopy( - op_dist_attr.get_input_dims_mapping(y_name)) + op_dist_attr.get_input_dims_mapping(y_name) + ) trans_x = op_desc.attr('trans_x') trans_y = op_desc.attr('trans_y') trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_replicate(x_dims_mapping[-1]): return False if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( - y_dims_mapping[-1]): + y_dims_mapping[-1] + ): return False # Other dimensions must be replicate except the batch dimension for mapping in x_dims_mapping[1:-1]: @@ -1752,8 +2012,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): return False @@ -1778,28 +2039,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(src_op)) + assert ( + op_dist_attr is not None + ), "backward 
op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, op_dist_attr.process_mesh, rank_id + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) assert len(kwargs[input_name]) == len( src_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in src_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( src_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0]) @@ -1809,28 +2075,39 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): # TODO infer logic comm presentation matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-2] + Weight_var.name + )[-2] if trans_y: matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-1] - assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( - matmul_row_dim_mapping) + Weight_var.name + )[-1] + assert ( + matmul_row_dim_mapping >= 0 + ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_row_dim_mapping + ) process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_row_dim_mapping - group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, - parallel_axis, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) - check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], - 'linear') - check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], - 'linear') + check_variable_and_dtype( + X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' + ) + check_dtype( + X_var.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) attrs = { 'trans_x': trans_x, 'trans_y': trans_y, - OP_ROLE_KEY: src_op.attr('op_role') + OP_ROLE_KEY: src_op.attr('op_role'), } inputs = {'X': X_var, 'Y': Weight_var} @@ -1839,27 +2116,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, - out_var_dist_attr) + ref_shape = infer_shape( + main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_allreduce_sum", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_allreduce_sum", 'tmp']) + ), shape=Out_var.shape, dtype=Out_var.dtype, type=Out_var.type, lod_level=Out_var.lod_level, persistable=False, is_data=False, - need_check_feed=Out_var.desc.need_check_feed()) + need_check_feed=Out_var.desc.need_check_feed(), + ) # 
set intermediate_var_0's dist_attr with Out_var's dist_attr - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - out_var_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, out_var_dist_attr + ) - matmul_v2_op = main_block.append_op(type='matmul_v2', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs) + matmul_v2_op = main_block.append_op( + type='matmul_v2', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs, + ) if intermediate_var_0.shape != ref_shape: intermediate_var_0.desc.set_shape(ref_shape) @@ -1871,8 +2154,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, - OP_ROLE_KEY: src_op.attr('op_role') - }) + OP_ROLE_KEY: src_op.attr('op_role'), + }, + ) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -1885,15 +2169,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): for input_varname in matmul_v2_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmulv2_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + matmulv2_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) output_varname = matmul_v2_op.desc.output_arg_names()[0] output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmulv2_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) + op_dist_attr + ) + matmulv2_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) # allreduce @@ -1905,21 +2193,26 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): input_var = main_block.var(input_varname) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) assert tensor_dist_attr is not None - allreduce_op_dist_attr.set_input_dist_attr(input_varname, - tensor_dist_attr) + allreduce_op_dist_attr.set_input_dist_attr( + input_varname, tensor_dist_attr + ) for output_varname in c_allreduce_sum_op.desc.output_arg_names(): output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - allreduce_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) - ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, - allreduce_op_dist_attr) + op_dist_attr + ) + allreduce_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) + ctx.set_op_dist_attr_for_program( + c_allreduce_sum_op, allreduce_op_dist_attr + ) # init param sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: - _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, - rank_id) + _init_param_sync( + Weight_var, dist_op_context, startup_block, ctx, rank_id + ) @staticmethod def backward(ctx, *args, **kwargs): @@ -1928,7 +2221,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): # ReplicateParallel class DistributedMatmulV2Impl2(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMatmulV2Impl2, self).__init__(name) @@ -1950,38 +2242,44 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): process_mesh = dist_attr.process_mesh # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = 
build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MatmulV2OpCost, ctx, processes, desc_mapping, cluster + ) res_cost = [cost_mapping] @@ -1998,13 +2296,15 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): if is_dim_shard(x_dims_mapping[-1]): return False if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard( - x_dims_mapping[-2]): + x_dims_mapping[-2] + ): return False if is_dim_shard(y_dims_mapping[-1]): return False if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard( - y_dims_mapping[-2]): + y_dims_mapping[-2] + ): return False return True @@ -2019,14 +2319,16 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): if is_dim_shard(out_dims_mapping[-1]): return False if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard( - out_dims_mapping[-2]): + out_dims_mapping[-2] + ): return False return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): @@ -2050,16 +2352,18 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) -register_distributed_operator_impl("matmul_v2", - DistributedMatmulV2Impl0("column_parallel")) -register_distributed_operator_impl("matmul_v2", - DistributedMatmulV2Impl1("row_parallel")) register_distributed_operator_impl( - "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")) + "matmul_v2", DistributedMatmulV2Impl0("column_parallel") +) +register_distributed_operator_impl( + "matmul_v2", DistributedMatmulV2Impl1("row_parallel") +) +register_distributed_operator_impl( + "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel") +) class DistributedMul(DistributedOperatorImplContainer): - def __init__(self, op_type): super(DistributedMul, self).__init__(op_type) @@ -2069,7 +2373,6 @@ 
register_distributed_operator_impl_container(DistributedMul("mul")) # ColumnParallel class DistributedMulImpl0(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMulImpl0, self).__init__(name) self._forward_implemented = True @@ -2092,7 +2395,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): main_block = backward_op.block vars = main_block.vars Y_var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("Y")[0]) + backward_op.input("Y")[0] + ) # col parallel: matmul + allreduce assert Y_var_dim_mapping[0] < 0 parallel_axis = Y_var_dim_mapping[1] @@ -2102,13 +2406,14 @@ class DistributedMulImpl0(DistributedOperatorImpl): assert len(backward_op.output("X@GRAD")) == 1 # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) process_mesh = dist_attr.process_mesh processes = process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MulGradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # calc comm op cost @@ -2121,40 +2426,52 @@ class DistributedMulImpl0(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - AllreduceSumOpCost, ctx, processes, - c_allreduce_sum_desc_mapping, cluster) + AllreduceSumOpCost, + ctx, + processes, + c_allreduce_sum_desc_mapping, + cluster, + ) res.append(comm_op_cost_list) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, - desc_mapping, cluster) + cost_mapping = build_comp_costs_from_descs( + MulOpCost, ctx, processes, desc_mapping, cluster + ) # calc comm op cost serial_op = dist_op.serial_op vars = serial_op.block.vars parallel_axis = dist_op.dist_attr.get_input_dims_mapping( - serial_op.input("Y")[0])[-1] + serial_op.input("Y")[0] + )[-1] attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = serial_op.input("X") c_identity_desc_mapping = build_comm_desc_from_dist_op( @@ -2163,10 +2480,12 @@ class DistributedMulImpl0(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) comm_op_cost_list = build_comm_costs_from_descs( - IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + 
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster + ) res_cost = [comm_op_cost_list, cost_mapping] return res_cost @@ -2181,7 +2500,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): if is_dim_shard(x_dims_mapping[-1]): return False if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( - y_dims_mapping[-1]): + y_dims_mapping[-1] + ): return False for mapping in x_dims_mapping[1:-1]: if is_dim_shard(mapping): @@ -2201,8 +2521,9 @@ class DistributedMulImpl0(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): @@ -2229,28 +2550,33 @@ class DistributedMulImpl0(DistributedOperatorImpl): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(src_op)) + assert ( + op_dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, op_dist_attr.process_mesh, rank_id + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) assert len(kwargs[input_name]) == len( src_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in src_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( src_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0]) @@ -2258,15 +2584,20 @@ class DistributedMulImpl0(DistributedOperatorImpl): # TODO infer logic comm presentation matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-1] - assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( - matmul_col_dim_mapping) + Weight_var.name + )[-1] + assert ( + matmul_col_dim_mapping >= 0 + ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_col_dim_mapping + ) process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_col_dim_mapping - group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, - parallel_axis, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) # infer new var shape with op dist attr @@ -2274,31 +2605,39 @@ class DistributedMulImpl0(DistributedOperatorImpl): assert x_tensor_dist_attr is not None identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) assert identity_var_dist_attr is not None - ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, - identity_var_dist_attr) + 
ref_shape_x = infer_shape( + main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr + ) # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, - out_var_dist_attr) + ref_shape_out = infer_shape( + main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_identity", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_identity", 'tmp']) + ), dtype=X_var.dtype, shape=X_var.shape, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=X_var.stop_gradient) + stop_gradient=X_var.stop_gradient, + ) # set intermediate_var_0's dist_attr with X_var's dist_attr - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - identity_var_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, identity_var_dist_attr + ) check_variable_and_dtype( - X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') + X_var, + 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], + '_c_identity', + ) c_identity_op = main_block.append_op( type='c_identity', inputs={'X': [X_var]}, @@ -2307,20 +2646,29 @@ class DistributedMulImpl0(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, - OP_ROLE_KEY: src_op.attr('op_role') - }) + OP_ROLE_KEY: src_op.attr('op_role'), + }, + ) if intermediate_var_0.shape != ref_shape_x: intermediate_var_0.desc.set_shape(ref_shape_x) - check_variable_and_dtype(intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], 'linear') - check_dtype(intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') + check_variable_and_dtype( + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) + check_dtype( + intermediate_var_0.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) # attrs = {'trans_x': False, 'trans_y': False} attrs = { "x_num_col_dims": src_op.desc.attr("x_num_col_dims"), "y_num_col_dims": src_op.desc.attr("y_num_col_dims"), - OP_ROLE_KEY: src_op.attr('op_role') + OP_ROLE_KEY: src_op.attr('op_role'), } inputs = {'X': intermediate_var_0, 'Y': Weight_var} @@ -2334,16 +2682,15 @@ class DistributedMulImpl0(DistributedOperatorImpl): inputs_original_shape[var_name] = var.shape input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) - input_ref_shape = infer_shape(main_block, var, - input_tensor_dist_attr, - input_var_dist_attr) + input_ref_shape = infer_shape( + main_block, var, input_tensor_dist_attr, input_var_dist_attr + ) inputs_ref_shape[var_name] = input_ref_shape var.desc.set_shape(input_ref_shape) - mul_op = main_block.append_op(type='mul', - inputs=inputs, - outputs={'Out': Out_var}, - attrs=attrs) + mul_op = main_block.append_op( + type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs + ) if Out_var.shape != ref_shape_out: Out_var.desc.set_shape(ref_shape_out) @@ -2362,13 +2709,16 @@ class DistributedMulImpl0(DistributedOperatorImpl): input_varname = c_identity_op.desc.input_arg_names()[0] input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) assert 
input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - identity_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + identity_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) # output output_varname = c_identity_op.desc.output_arg_names()[0] - identity_op_dist_attr.set_output_dist_attr(output_varname, - input_dist_attr) + identity_op_dist_attr.set_output_dist_attr( + output_varname, input_dist_attr + ) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) # matmulv2 @@ -2379,29 +2729,37 @@ class DistributedMulImpl0(DistributedOperatorImpl): for input_varname in mul_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr( - input_varname) + input_varname + ) assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) + op_dist_attr + ) matmulv2_op_dist_attr.set_input_dist_attr( - input_varname, input_dist_attr) + input_varname, input_dist_attr + ) else: input_var = main_block.var(input_varname) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( - input_var) + input_var + ) matmulv2_op_dist_attr.set_input_dist_attr( - input_varname, tensor_dist_attr) + input_varname, tensor_dist_attr + ) for output_varname in mul_op.desc.output_arg_names(): output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmulv2_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) + op_dist_attr + ) + matmulv2_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr) # init param sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: - _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, - rank_id) + _init_param_sync( + Weight_var, dist_op_context, startup_block, ctx, rank_id + ) @staticmethod def backward(ctx, *args, **kwargs): @@ -2410,7 +2768,6 @@ class DistributedMulImpl0(DistributedOperatorImpl): # RowParallel class DistributedMulImpl1(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMulImpl1, self).__init__(name) self._forward_implemented = True @@ -2434,7 +2791,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): main_block = backward_op.block vars = main_block.vars Y_var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("Y")[0]) + backward_op.input("Y")[0] + ) assert Y_var_dim_mapping[1] < 0 parallel_axis = Y_var_dim_mapping[0] @@ -2447,49 +2805,59 @@ class DistributedMulImpl1(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) processes = process_mesh.processes comm_op_cost_list = build_comm_costs_from_descs( - IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) + IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster + ) res.append(comm_op_cost_list) # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) - cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, - processes, desc_mapping, - cluster) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) + cost_mapping = build_comp_costs_from_descs( + MulGradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + 
backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, - desc_mapping, cluster) + cost_mapping = build_comp_costs_from_descs( + MulOpCost, ctx, processes, desc_mapping, cluster + ) # calc comm op cost serial_op = dist_op.serial_op vars = serial_op.block.vars parallel_axis = dist_op.dist_attr.get_input_dims_mapping( - serial_op.input("Y")[0])[-2] + serial_op.input("Y")[0] + )[-2] attrs = {"use_calc_stream": True, "use_model_parallel": True} var_names = serial_op.output("Out") @@ -2499,12 +2867,17 @@ class DistributedMulImpl1(DistributedOperatorImpl): ctx, var_names, attrs=attrs, - parallel_axis=parallel_axis) + parallel_axis=parallel_axis, + ) # print("dist_matmul.py dist_op: ", dist_op) comm_op_cost_list = build_comm_costs_from_descs( - AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, - cluster) + AllreduceSumOpCost, + ctx, + processes, + c_allreduce_sum_desc_mapping, + cluster, + ) res_cost = [cost_mapping, comm_op_cost_list] @@ -2520,7 +2893,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): if is_dim_replicate(x_dims_mapping[-1]): return False if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( - y_dims_mapping[-1]): + y_dims_mapping[-1] + ): return False # Other dimensions must be replicate except the batch dimension for mapping in x_dims_mapping[1:-1]: @@ -2542,8 +2916,9 @@ class DistributedMulImpl1(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): @@ -2570,28 +2945,33 @@ class DistributedMulImpl1(DistributedOperatorImpl): src_op = dist_op_context.cur_src_op rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(src_op)) + assert ( + op_dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in op_dist_attr.process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, op_dist_attr.process_mesh, rank_id + ) # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) 
assert len(kwargs[input_name]) == len( src_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in src_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( src_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0]) @@ -2599,26 +2979,36 @@ class DistributedMulImpl1(DistributedOperatorImpl): # TODO infer logic comm presentation matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[-2] - assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( - matmul_row_dim_mapping) + Weight_var.name + )[-2] + assert ( + matmul_row_dim_mapping >= 0 + ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_row_dim_mapping + ) process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_row_dim_mapping - group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, - parallel_axis, rank_id) + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id + ) group = new_process_group(group_ranks) - check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], - 'linear') - check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], - 'linear') + check_variable_and_dtype( + X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' + ) + check_dtype( + X_var.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', + ) # attrs = {'trans_x': False, 'trans_y': False} attrs = { "x_num_col_dims": src_op.desc.attr("x_num_col_dims"), "y_num_col_dims": src_op.desc.attr("y_num_col_dims"), - OP_ROLE_KEY: src_op.attr('op_role') + OP_ROLE_KEY: src_op.attr('op_role'), } inputs = {'X': X_var, 'Y': Weight_var} @@ -2627,22 +3017,26 @@ class DistributedMulImpl1(DistributedOperatorImpl): assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, - out_var_dist_attr) + ref_shape = infer_shape( + main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr + ) intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_allreduce_sum", 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(["c_allreduce_sum", 'tmp']) + ), shape=Out_var.shape, dtype=Out_var.dtype, type=Out_var.type, lod_level=Out_var.lod_level, persistable=False, is_data=False, - need_check_feed=Out_var.desc.need_check_feed()) + need_check_feed=Out_var.desc.need_check_feed(), + ) # set intermediate_var_0's dist_attr with Out_var's dist_attr - ctx.set_tensor_dist_attr_for_program(intermediate_var_0, - out_var_dist_attr) + ctx.set_tensor_dist_attr_for_program( + intermediate_var_0, out_var_dist_attr + ) inputs_ref_shape = {} inputs_original_shape = {} @@ -2651,16 +3045,18 @@ class DistributedMulImpl1(DistributedOperatorImpl): inputs_original_shape[var_name] = var.shape input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) - input_ref_shape = 
infer_shape(main_block, var, - input_tensor_dist_attr, - input_var_dist_attr) + input_ref_shape = infer_shape( + main_block, var, input_tensor_dist_attr, input_var_dist_attr + ) inputs_ref_shape[var_name] = input_ref_shape var.desc.set_shape(input_ref_shape) - mul_op = main_block.append_op(type='mul', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs) + mul_op = main_block.append_op( + type='mul', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs, + ) if intermediate_var_0.shape != ref_shape: intermediate_var_0.desc.set_shape(ref_shape) @@ -2678,8 +3074,9 @@ class DistributedMulImpl1(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, - OP_ROLE_KEY: src_op.attr('op_role') - }) + OP_ROLE_KEY: src_op.attr('op_role'), + }, + ) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -2693,15 +3090,19 @@ class DistributedMulImpl1(DistributedOperatorImpl): for input_varname in mul_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) assert input_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmulv2_op_dist_attr.set_input_dist_attr(input_varname, - input_dist_attr) + op_dist_attr + ) + matmulv2_op_dist_attr.set_input_dist_attr( + input_varname, input_dist_attr + ) output_varname = mul_op.desc.output_arg_names()[0] output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - matmulv2_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) + op_dist_attr + ) + matmulv2_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr) # allreduce @@ -2713,21 +3114,26 @@ class DistributedMulImpl1(DistributedOperatorImpl): input_var = main_block.var(input_varname) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) assert tensor_dist_attr is not None - allreduce_op_dist_attr.set_input_dist_attr(input_varname, - tensor_dist_attr) + allreduce_op_dist_attr.set_input_dist_attr( + input_varname, tensor_dist_attr + ) for output_varname in c_allreduce_sum_op.desc.output_arg_names(): output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) assert output_dist_attr is not None, "dist_attr is {}".format( - op_dist_attr) - allreduce_op_dist_attr.set_output_dist_attr(output_varname, - output_dist_attr) - ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, - allreduce_op_dist_attr) + op_dist_attr + ) + allreduce_op_dist_attr.set_output_dist_attr( + output_varname, output_dist_attr + ) + ctx.set_op_dist_attr_for_program( + c_allreduce_sum_op, allreduce_op_dist_attr + ) # init param sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: - _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, - rank_id) + _init_param_sync( + Weight_var, dist_op_context, startup_block, ctx, rank_id + ) @staticmethod def backward(ctx, *args, **kwargs): @@ -2736,7 +3142,6 @@ class DistributedMulImpl1(DistributedOperatorImpl): # ReplicateParallel class DistributedMulImpl2(DistributedOperatorImpl): - def __init__(self, name): super(DistributedMulImpl2, self).__init__(name) @@ -2757,38 +3162,45 @@ class DistributedMulImpl2(DistributedOperatorImpl): vars = main_block.vars # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) 
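# [editor's note] The batch_size_axis condition that recurs in each calc_bwd_cost
# in this file decides whether a data-parallel gradient allreduce cost is added:
# the batch dim must be sharded on a mesh axis of size > 1 and Y must be a
# parameter. A self-contained restatement with illustrative names:
def needs_grad_allreduce(x_dims_mapping, mesh_shape, y_is_parameter):
    batch_axis = x_dims_mapping[0]          # mesh axis the batch dim is mapped to
    return (
        batch_axis > -1                     # batch dim actually sharded
        and mesh_shape[batch_axis] > 1      # more than one rank on that axis
        and y_is_parameter                  # only parameter grads need dp sync
    )

# e.g. X sharded as [0, -1] on a (2, 4) mesh with a parameter weight:
assert needs_grad_allreduce([0, -1], (2, 4), True) is True
assert needs_grad_allreduce([-1, -1], (2, 4), True) is False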
process_mesh = dist_attr.process_mesh processes = process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, - processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + MulGradOpCost, ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) # need gradient allreduce var_dim_mapping = dist_attr.get_input_dims_mapping( - backward_op.input("X")[0]) + backward_op.input("X")[0] + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] - if batch_size_axis > -1 and mesh_shape[ - batch_size_axis] > 1 and is_parameter_related( - backward_op.input("Y")[0], main_block): + if ( + batch_size_axis > -1 + and mesh_shape[batch_size_axis] > 1 + and is_parameter_related(backward_op.input("Y")[0], main_block) + ): parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [backward_op.output('Y@GRAD')[0]] - build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, - cluster) + build_dp_costs( + res, dist_op, ctx, var_names, attrs, parallel_axis, cluster + ) return res def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes - cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, - desc_mapping, cluster) + cost_mapping = build_comp_costs_from_descs( + MulOpCost, ctx, processes, desc_mapping, cluster + ) res_cost = [cost_mapping] return res_cost @@ -2804,12 +3216,14 @@ class DistributedMulImpl2(DistributedOperatorImpl): if is_dim_shard(x_dims_mapping[-1]): return False if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard( - x_dims_mapping[-2]): + x_dims_mapping[-2] + ): return False if is_dim_shard(y_dims_mapping[-1]): return False if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard( - y_dims_mapping[-2]): + y_dims_mapping[-2] + ): return False return True @@ -2824,14 +3238,16 @@ class DistributedMulImpl2(DistributedOperatorImpl): if is_dim_shard(out_dims_mapping[-1]): return False if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard( - out_dims_mapping[-2]): + out_dims_mapping[-2] + ): return False return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False if not _is_auto_compatible_for_matmul(dist_op): @@ -2855,8 +3271,10 @@ class DistributedMulImpl2(DistributedOperatorImpl): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) -register_distributed_operator_impl("mul", - DistributedMulImpl0("column_parallel")) +register_distributed_operator_impl( + "mul", DistributedMulImpl0("column_parallel") +) register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel")) -register_distributed_operator_impl("mul", - DistributedMulImpl2("replicate_parallel")) +register_distributed_operator_impl( + "mul", DistributedMulImpl2("replicate_parallel") +) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 6b997d888a4..09f5f6464bc 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -254,17 +254,26 @@ class Parallelizer: 
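# [editor's note] The hunk below keys the pass choice on config["dtype"] and
# config["level"]: "o1" goes through the auto_parallel_amp pass, while "o2"/"o3"
# go through the auto_parallel_fp16 pass. A hedged usage sketch of driving this
# from the user-facing strategy; the Strategy/Engine calls are assumed from the
# auto-parallel API, and model/loss/optimizer names are placeholders.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.amp.enable = True
strategy.amp.dtype = "bfloat16"   # or "float16"
strategy.amp.level = "o2"         # "o1" -> amp pass, "o2"/"o3" -> fp16 pass

# engine = auto.Engine(model, loss_fn, optimizer, strategy=strategy)  # placeholders
# engine.fit(train_dataset, batch_size=8, epochs=1)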
self._dist_context.serial_feed_vars["inputs"] + self._dist_context.serial_feed_vars["labels"] ) - if config["use_pure_fp16"]: + self._logger.info( + "Applying AMP-{}-{} ...".format( + config["dtype"], config['level'] + ), + ) + if config['level'] == "o1": + auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) + auto_parallel_amp_pass.apply( + [main_program], [startup_program], self._pass_context + ) + loss = auto_parallel_amp_pass.get_loss() + elif config['level'] in ['o2', 'o3']: config["base_opt"] = optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass.apply( [main_program], [startup_program], self._pass_context ) + loss = auto_parallel_fp16_pass.get_loss() else: - auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) - auto_parallel_amp_pass.apply( - [main_program], [startup_program], self._pass_context - ) + raise ValueError("AMP level should be one of o1, o2, o3") # apply recompute pass # recompute is then train-only optimization diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 9e0aaa64485..a61d9aaf4e6 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -18,25 +18,48 @@ from paddle.fluid import unique_name from .pass_base import PassBase, register_pass from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.fluid.data_feeder import check_variable_and_dtype, check_type -from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr -from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping -from paddle.distributed.auto_parallel.process_group import get_world_process_group -from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists -from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _keep_fp32_output, find_op_index -from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op -from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg -from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute +from paddle.distributed.auto_parallel.utils import ( + get_loss_op, + set_var_dist_attr, +) +from paddle.distributed.auto_parallel.utils import ( + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, +) +from paddle.distributed.auto_parallel.process_group import ( + get_world_process_group, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + AutoMixedPrecisionLists, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + _keep_fp32_input, + _keep_fp32_output, + find_op_index, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + _valid_types, + find_true_post_op, + find_true_prev_op, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + _is_in_black_varnames, + _dtype_to_str, + _rename_arg, +) +from paddle.distributed.auto_parallel.dist_attribute import ( + OperatorDistributedAttribute, +) from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op world_process_group = get_world_process_group() class AMPState(object): - def __init__(self, block): self._block = block - self._op_fp16_dict = { - } # op_id --> True/False. 'True' means that the current op is in fp16 mode. + self._op_fp16_dict = ( + {} + ) # op_id --> True/False. 
'True' means that the current op is in fp16 mode. self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name} self.is_train = False @@ -55,7 +78,8 @@ class AMPState(object): elif int(op.attr('op_role')) == int(OpRole.Backward): if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: fwd_op_id = dist_op_context.grad_op_id_to_op_id[ - op.desc.original_id()] + op.desc.original_id() + ] if self._is_fp16_op(fwd_op_id) == True: self._op_fp16_dict[op.desc.original_id()] = True elif self._is_fp16_op(fwd_op_id) == False: @@ -78,7 +102,8 @@ class AMPState(object): if op.type == 'create_py_reader' or op.type == 'read': continue if amp_lists.black_varnames is not None and _is_in_black_varnames( - op, amp_lists): + op, amp_lists + ): self._op_fp16_dict[op.desc.original_id()] = False continue if op.type in amp_lists.black_list: @@ -98,17 +123,24 @@ class AMPState(object): continue elif in_var.op is op: prev_op = find_true_prev_op( - ops, op, in_var_name) + ops, op, in_var_name + ) if prev_op is None: continue else: prev_op = in_var.op # if it's one of inputs - if self._is_fp16_op(prev_op.desc.original_id()) == False or \ - prev_op.type in amp_lists.black_list: + if ( + self._is_fp16_op(prev_op.desc.original_id()) + == False + or prev_op.type in amp_lists.black_list + ): is_black_op = True - elif self._is_fp16_op(prev_op.desc.original_id()) == True or \ - prev_op.type in amp_lists.white_list: + elif ( + self._is_fp16_op(prev_op.desc.original_id()) + == True + or prev_op.type in amp_lists.white_list + ): is_white_op = True if is_black_op: self._op_fp16_dict[op.desc.original_id()] = False @@ -131,19 +163,28 @@ class AMPState(object): break if self._is_fp16_op(op.desc.original_id()) == False: num_cast_ops = self._insert_cast_op_forward( - op, idx, core.VarDesc.VarType.FP16, - core.VarDesc.VarType.FP32, dist_context) + op, + idx, + core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32, + dist_context, + ) elif self._is_fp16_op(op.desc.original_id()) == True: num_cast_ops = self._insert_cast_op_forward( - op, idx, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, dist_context) + op, + idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, + dist_context, + ) else: pass idx += num_cast_ops + 1 self._block._sync_with_cpp() - def _insert_cast_op_forward(self, op, idx, src_dtype, dst_dtype, - dist_context): + def _insert_cast_op_forward( + self, op, idx, src_dtype, dst_dtype, dist_context + ): """ only for forward cast modified from paddle.fluid.contrib.mixed_precision @@ -152,38 +193,45 @@ class AMPState(object): var_name_dict = {} for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( - op, in_name): + op, in_name + ): continue for in_var_name in op.input(in_name): in_var = self._block._find_var_recursive(in_var_name) if in_var.type not in _valid_types or in_var.dtype == dst_dtype: continue if in_var.dtype == src_dtype: - cast_name = in_var.name + '.cast_' + _dtype_to_str( - dst_dtype) + cast_name = ( + in_var.name + '.cast_' + _dtype_to_str(dst_dtype) + ) out_var = self._block.vars.get(cast_name) var_name_dict[in_var.name] = cast_name consume_op_attr = dist_context.get_op_dist_attr_for_program( - op) + op + ) assert consume_op_attr is not None if out_var is None or out_var.dtype != dst_dtype: # NOTE we make the cast op and var's dist attr as the op that consume the # cast var instead of the op which generates the var in_var_dist_attr = consume_op_attr.get_input_dist_attr( - in_var.name) + in_var.name + ) assert in_var_dist_attr is not None 
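# [editor's note] The cast handling here derives a shared "<name>.cast_<dtype>"
# variable (via _dtype_to_str) and only creates a new cast var/op when no
# compatible one already exists, so several low-precision consumers of the same
# fp32 input reuse a single cast. A tiny standalone restatement of that reuse
# rule, using a plain dict in place of block.vars and illustrative names/suffixes:
existing_casts = {}  # cast_name -> dtype, stands in for self._block.vars

def resolve_cast(in_name, dst_dtype):
    suffix = "fp16" if dst_dtype == "float16" else "fp32"
    cast_name = in_name + ".cast_" + suffix
    reuse = existing_casts.get(cast_name) == dst_dtype
    if not reuse:
        existing_casts[cast_name] = dst_dtype   # first consumer inserts the cast op
    return cast_name, not reuse                 # (name, whether a new cast op is needed)

assert resolve_cast("linear_0.tmp_0", "float16") == ("linear_0.tmp_0.cast_fp16", True)
assert resolve_cast("linear_0.tmp_0", "float16") == ("linear_0.tmp_0.cast_fp16", False)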
ref_mesh = in_var_dist_attr.process_mesh ref_mapping = in_var_dist_attr.dims_mapping consume_op_attr.set_input_dist_attr( - cast_name, in_var_dist_attr) + cast_name, in_var_dist_attr + ) out_var = self._block.create_var( name=cast_name, dtype=dst_dtype, persistable=False, - stop_gradient=in_var.stop_gradient) - set_var_dist_attr(dist_context, out_var, ref_mapping, - ref_mesh) + stop_gradient=in_var.stop_gradient, + ) + set_var_dist_attr( + dist_context, out_var, ref_mapping, ref_mesh + ) cast_op = self._block._insert_op_without_sync( idx, @@ -193,22 +241,29 @@ class AMPState(object): attrs={ "in_dtype": in_var.dtype, "out_dtype": out_var.dtype, - }) + }, + ) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - cast_op, ref_mesh, ref_mapping, dist_context) + cast_op, ref_mesh, ref_mapping, dist_context + ) num_cast_ops += 1 else: in_var_dist_attr = consume_op_attr.get_input_dist_attr( - in_var.name) + in_var.name + ) consume_op_attr.set_input_dist_attr( - cast_name, in_var_dist_attr) + cast_name, in_var_dist_attr + ) _rename_arg(op, in_var.name, cast_name) else: if op.has_attr('in_dtype'): op._set_attr('in_dtype', dst_dtype) self._var_name_dict[op.desc.original_id()] = var_name_dict - if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16: + if ( + src_dtype == core.VarDesc.VarType.FP32 + and dst_dtype == core.VarDesc.VarType.FP16 + ): for out_name in op.output_names: if _keep_fp32_output(op, out_name): continue @@ -238,8 +293,9 @@ class AMPState(object): # NOTE: the map in `grad_var_to_var` may be changed when the var is casted, # which will affect the dist_op to insert allreduce_sum op. op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op) - if is_backward_op(grad_op) and (is_forward_op(ops[idx - 1]) - or is_loss_op(ops[idx - 1])): + if is_backward_op(grad_op) and ( + is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1]) + ): if not op_dist_attr.is_recompute: appended_grad_times += 1 @@ -248,14 +304,22 @@ class AMPState(object): if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id: if self._is_fp16_op(grad_op_orig_id) == False: # fp32 num_cast_ops = self._insert_cast_op_backward( - grad_op, idx, core.VarDesc.VarType.FP16, - core.VarDesc.VarType.FP32, dist_context, - appended_grad_times) + grad_op, + idx, + core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32, + dist_context, + appended_grad_times, + ) elif self._is_fp16_op(grad_op_orig_id) == True: # fp16 num_cast_ops = self._insert_cast_op_backward( - grad_op, idx, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, dist_context, - appended_grad_times) + grad_op, + idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, + dist_context, + appended_grad_times, + ) elif grad_op.type == "sum": in_var_name = grad_op.desc.input_arg_names()[0] src_dtype = self._block.var(in_var_name).dtype @@ -270,15 +334,24 @@ class AMPState(object): else: raise ValueError( "'{}' op is not supported in the complete amp pass.".format( - grad_op.type)) + grad_op.type + ) + ) idx += num_cast_ops + 1 self._block._sync_with_cpp() _update_backward_cast_ops(params_grads, dist_context) - def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype, - dist_context, appended_grad_times): - """ only for backward cast """ + def _insert_cast_op_backward( + self, + grad_op, + idx, + src_dtype, + dst_dtype, + dist_context, + appended_grad_times, + ): + """only for backward cast""" def _keep_fp32_input(op, in_name): op_type = op.type @@ -299,7 +372,8 @@ class AMPState(object): for in_name in 
grad_op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( - grad_op, in_name): + grad_op, in_name + ): for in_var_name in grad_op.input(in_name): in_var = self._block._find_var_recursive(in_var_name) assert in_var.dtype == core.VarDesc.VarType.FP32 @@ -309,24 +383,34 @@ class AMPState(object): in_var = self._block._find_var_recursive(in_var_name) if in_var.dtype == src_dtype: consume_op_attr = dist_context.get_op_dist_attr_for_program( - grad_op) + grad_op + ) if in_var_name in self._var_name_dict[fwd_op_id]: # NOTE: if in_var of consume grad_op has been casted before, # it should be renamed and reset dist_attr. cast_name = self._var_name_dict[fwd_op_id][in_var_name] grad_op.desc._rename_input(in_var_name, cast_name) in_var_dist_attr = consume_op_attr.get_input_dist_attr( - in_var_name) + in_var_name + ) consume_op_attr.set_input_dist_attr( - cast_name, in_var_dist_attr) + cast_name, in_var_dist_attr + ) else: - assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format( - grad_op.type, in_name, dst_dtype, in_var.dtype, - str(grad_op)) + assert ( + in_var.dtype == dst_dtype + ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format( + grad_op.type, + in_name, + dst_dtype, + in_var.dtype, + str(grad_op), + ) for out_name in grad_op.output_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output( - grad_op, out_name): + grad_op, out_name + ): for out_var_name in grad_op.output(out_name): out_var = self._block._find_var_recursive(out_var_name) assert out_var.dtype == core.VarDesc.VarType.FP32 @@ -334,7 +418,7 @@ class AMPState(object): for out_var_name in grad_op.output(out_name): out_var = self._block._find_var_recursive(out_var_name) - out_var_name_prefix = out_var_name[:out_var_name.find("@")] + out_var_name_prefix = out_var_name[: out_var_name.find("@")] fwd_var = self._block._find_var_recursive(out_var_name_prefix) # NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype if out_var.dtype != fwd_var.dtype: @@ -345,34 +429,45 @@ class AMPState(object): # NOTE: if out_var of consume grad_op has been casted before, # it should be renamed and reset dist_attr, then we insert cast op to # convert the cast_var to original dtype - consume_op_attr = dist_context.get_op_dist_attr_for_program( - grad_op) + consume_op_attr = ( + dist_context.get_op_dist_attr_for_program(grad_op) + ) fwd_cast_name = self._var_name_dict[fwd_op_id][ - out_var_name_prefix] + out_var_name_prefix + ] suffix = "" if "@RENAME" in out_var_name: - suffix = out_var_name[out_var_name.find("@RENAME"):] + suffix = out_var_name[ + out_var_name.find("@RENAME") : + ] cast_name = fwd_cast_name + "@GRAD" + suffix cast_var = self._block.vars.get(cast_name) if cast_var is None or cast_var.dtype != dst_dtype: grad_op.desc._rename_output(out_var_name, cast_name) - out_var_dist_attr = consume_op_attr.get_output_dist_attr( - out_var_name) + out_var_dist_attr = ( + consume_op_attr.get_output_dist_attr( + out_var_name + ) + ) ref_mesh = out_var_dist_attr.process_mesh ref_mapping = out_var_dist_attr.dims_mapping consume_op_attr.set_output_dist_attr( - cast_name, out_var_dist_attr) + cast_name, out_var_dist_attr + ) assert ref_mapping is not None cast_var = self._block.create_var( name=cast_name, shape=out_var.shape, dtype=dst_dtype, persistable=False, - stop_gradient=out_var.stop_gradient) - set_var_dist_attr(dist_context, cast_var, - ref_mapping, ref_mesh) + stop_gradient=out_var.stop_gradient, + ) + set_var_dist_attr( 
+ dist_context, cast_var, ref_mapping, ref_mesh + ) dist_op_context.grad_var_to_var[ - appended_grad_times][cast_name] = fwd_cast_name + appended_grad_times + ][cast_name] = fwd_cast_name cast_op = self._block._insert_op( idx + 1, @@ -382,13 +477,15 @@ class AMPState(object): attrs={ "in_dtype": cast_var.dtype, "out_dtype": out_var.dtype, - "op_role": OpRole.Backward - }) + "op_role": OpRole.Backward, + }, + ) cast_op._remove_attr("op_role_var") cast_op._remove_attr("op_namescope") cast_op._remove_attr("with_quant_attr") naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - cast_op, ref_mesh, ref_mapping, dist_context) + cast_op, ref_mesh, ref_mapping, dist_context + ) num_cast_ops += 1 else: assert out_var.dtype == dst_dtype @@ -409,15 +506,18 @@ def _update_backward_cast_ops(params_grads, dist_context): for p, g in params_grads: op = g.op if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast': - if int(op.attr('op_role')) == int( - OpRole.Backward) and op.has_attr('op_role_var'): + if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr( + 'op_role_var' + ): op._remove_attr("op_role_var") post_ops = find_true_post_op(main_block.ops, op, g.name) if post_ops: - raise ValueError("The cast op {0}'s output should not be" - "used by a non-optimize op, however, it" - "is used by {1}".format(op, post_ops[0])) + raise ValueError( + "The cast op {0}'s output should not be" + "used by a non-optimize op, however, it" + "is used by {1}".format(op, post_ops[0]) + ) if op == main_block.ops[-1]: continue @@ -425,23 +525,29 @@ def _update_backward_cast_ops(params_grads, dist_context): # add new op in the python and cpp at the same time new_op_desc = main_block.desc.append_op() new_op_desc.copy_from(op.desc) - new_op = paddle.fluid.framework.Operator(block=main_block, - desc=new_op_desc, - type=None, - inputs=None, - outputs=None, - attrs=None) + new_op = paddle.fluid.framework.Operator( + block=main_block, + desc=new_op_desc, + type=None, + inputs=None, + outputs=None, + attrs=None, + ) main_block.ops.append(new_op) # dist attr param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p) output_dist_attr = dist_context.get_tensor_dist_attr_for_program( - main_block.var(op.output_arg_names[0])) + main_block.var(op.output_arg_names[0]) + ) assert param_dist_attr is not None assert output_dist_attr is not None naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - new_op, param_dist_attr.process_mesh, - param_dist_attr.dims_mapping, dist_context) + new_op, + param_dist_attr.process_mesh, + param_dist_attr.dims_mapping, + dist_context, + ) output_dist_attr.process_mesh = param_dist_attr.process_mesh output_dist_attr.dims_mapping = param_dist_attr.dims_mapping @@ -462,26 +568,34 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): grads = [g for _, g in params_grads] check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale') for e in grads: - check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], - 'check_finite_and_unscale') + check_variable_and_dtype( + e, + "x", + ['float16', 'float32', 'float64'], + 'check_finite_and_unscale', + ) found_inf = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ['find_infinite_scale', 'tmp'])), + name=unique_name.generate_with_ignorable_key( + ".".join(['find_infinite_scale', 'tmp']) + ), shape=[1], dtype='bool', type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) set_var_dist_attr(dist_context, found_inf, [-1], 
world_process_group.ranks) inputs = {'X': grads, 'Scale': loss_scaling} outputs = {'Out': grads, 'FoundInfinite': found_inf} attrs = {'op_role': OpRole.Optimize} - new_op = main_block.append_op(type='check_finite_and_unscale', - inputs=inputs, - outputs=outputs, - attrs=attrs) + new_op = main_block.append_op( + type='check_finite_and_unscale', + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr.process_mesh = world_process_group.ranks @@ -491,17 +605,18 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): for g in grads: g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g) assert g_dist_attr is not None - new_op_dist_attr.set_input_dims_mapping(g.name, - g_dist_attr.dims_mapping) - new_op_dist_attr.set_output_dims_mapping(g.name, - g_dist_attr.dims_mapping) + new_op_dist_attr.set_input_dims_mapping( + g.name, g_dist_attr.dims_mapping + ) + new_op_dist_attr.set_output_dims_mapping( + g.name, g_dist_attr.dims_mapping + ) dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) return grads, found_inf @register_pass("auto_parallel_amp") class AMPPass(PassBase): - def __init__(self): super(AMPPass, self).__init__() self.set_attr("loss", None) @@ -517,6 +632,7 @@ class AMPPass(PassBase): self.set_attr("use_dynamic_loss_scaling", False) self.set_attr("input_data", []) self.set_attr("params_grads", []) + self.set_attr("dtype", "") # fp16/bf16 self._loss = None self._loss_scaling = None self._num_good_steps = None @@ -524,6 +640,8 @@ class AMPPass(PassBase): self._loss = None def _check_self(self): + if self.get_attr("dtype") not in ["float16", "bfloat16"]: + return False if self.get_attr("init_loss_scaling") < 0: return False if self.get_attr("incr_every_n_steps") < 0: @@ -548,11 +666,13 @@ class AMPPass(PassBase): def _apply_single_impl(self, main_program, startup_program, context): self.dist_context = self.get_attr("dist_context") params_grads = self.get_attr("params_grads") + self.amp_dtype = self.get_attr("dtype") amp_lists = AutoMixedPrecisionLists( set(self.get_attr("custom_white_list")), set(self.get_attr("custom_black_list")), - set(self.get_attr("custom_black_varnames"))) + set(self.get_attr("custom_black_varnames")), + ) with paddle.static.program_guard(main_program, startup_program): amp_state = AMPState(main_program.global_block()) @@ -566,10 +686,13 @@ class AMPPass(PassBase): self._init_amp_var() self._scale_loss() - if self.get_attr("use_dynamic_loss_scaling" - ) or self.get_attr("init_loss_scaling") != 1.0: + if ( + self.get_attr("use_dynamic_loss_scaling") + or self.get_attr("init_loss_scaling") != 1.0 + ): grads, found_inf = _check_and_update_gradient( - params_grads, self._loss_scaling, self.dist_context) + params_grads, self._loss_scaling, self.dist_context + ) if self.get_attr("use_dynamic_loss_scaling"): self._update_loss_scaling(grads, found_inf) @@ -580,9 +703,14 @@ class AMPPass(PassBase): shape=[1], value=self.get_attr("init_loss_scaling"), dtype='float32', - persistable=True) - set_var_dist_attr(self.dist_context, self._loss_scaling, [-1], - world_process_group.ranks) + persistable=True, + ) + set_var_dist_attr( + self.dist_context, + self._loss_scaling, + [-1], + world_process_group.ranks, + ) if self.get_attr("use_dynamic_loss_scaling"): self._num_good_steps = paddle.static.create_global_var( @@ -590,18 +718,28 @@ class AMPPass(PassBase): shape=[1], value=0, dtype='int32', - persistable=True) - set_var_dist_attr(self.dist_context, self._num_good_steps, 
[-1], - world_process_group.ranks) + persistable=True, + ) + set_var_dist_attr( + self.dist_context, + self._num_good_steps, + [-1], + world_process_group.ranks, + ) self._num_bad_steps = paddle.static.create_global_var( name=unique_name.generate("num_bad_steps"), shape=[1], value=0, dtype='int32', - persistable=True) - set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1], - world_process_group.ranks) + persistable=True, + ) + set_var_dist_attr( + self.dist_context, + self._num_bad_steps, + [-1], + world_process_group.ranks, + ) def _scale_loss(self): @@ -613,7 +751,8 @@ class AMPPass(PassBase): assert loss is not None loss_op = loss.op loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program( - loss_op) + loss_op + ) if loss.dtype != core.VarDesc.VarType.FP32: # cast loss here will change the effective loss tensor for the computation graph @@ -626,10 +765,12 @@ class AMPPass(PassBase): tmp_name = unique_name.generate(loss.name + ".cast_fp32") cast_loss = main_block.create_var(name=tmp_name, dtype=dtype) loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program( - loss) + loss + ) ref_mesh = loss_op_dist_attr.process_mesh self.dist_context.set_tensor_dist_attr_for_program( - cast_loss, loss_dist_attr) + cast_loss, loss_dist_attr + ) loss_op_idx = find_op_index(main_block.desc, loss_op.desc) cast_op = main_block._insert_op( @@ -641,16 +782,21 @@ class AMPPass(PassBase): "in_dtype": loss.dtype, "out_dtype": core.VarDesc.VarType.FP32, 'op_role': loss_op.all_attrs()[OP_ROLE_KEY], - }) + }, + ) - loss_op._set_attr(OP_ROLE_KEY, - core.op_proto_and_checker_maker.OpRole.Forward) + loss_op._set_attr( + OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward + ) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - cast_op, ref_mesh, [-1], self.dist_context) + cast_op, ref_mesh, [-1], self.dist_context + ) loss = loss.astype('float32') - if self.get_attr("use_dynamic_loss_scaling" - ) or self.get_attr("init_loss_scaling") != 1.0: + if self.amp_dtype == "float16" and ( + self.get_attr("use_dynamic_loss_scaling") + or self.get_attr("init_loss_scaling") != 1.0 + ): loss_op_idx = find_op_index(main_block.desc, loss_op.desc) @@ -660,63 +806,76 @@ class AMPPass(PassBase): name=unique_name.generate("scaled_loss"), shape=loss.shape, dtype=loss.dtype, - persistable=loss.persistable) - set_var_dist_attr(self.dist_context, self._scaled_loss, [-1], - ref_mesh) + persistable=loss.persistable, + ) + set_var_dist_attr( + self.dist_context, self._scaled_loss, [-1], ref_mesh + ) elementwise_mul_op = main_block._insert_op( loss_op_idx + 1, type='elementwise_mul', - inputs={ - 'X': [loss], - 'Y': [self._loss_scaling] - }, + inputs={'X': [loss], 'Y': [self._loss_scaling]}, outputs={'Out': [self._scaled_loss]}, attrs={ 'op_role': loss_op.all_attrs()[OP_ROLE_KEY], - }) - loss_op._set_attr(OP_ROLE_KEY, - core.op_proto_and_checker_maker.OpRole.Forward) + }, + ) + loss_op._set_attr( + OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward + ) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - elementwise_mul_op, ref_mesh, [-1], self.dist_context) + elementwise_mul_op, ref_mesh, [-1], self.dist_context + ) # backward first_backward_op = main_block.ops[loss_op_idx + 2] - assert first_backward_op.type == "fill_constant" and int( - first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257 + assert ( + first_backward_op.type == "fill_constant" + and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257 + ) self._scaled_loss_grad = main_block.create_var( 
name=unique_name.generate("scaled_loss") + "@GRAD", shape=loss.shape, dtype=loss.dtype, - persistable=loss.persistable) - set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1], - ref_mesh) + persistable=loss.persistable, + ) + set_var_dist_attr( + self.dist_context, self._scaled_loss_grad, [-1], ref_mesh + ) pre_grad_name = first_backward_op.output_arg_names[0] - first_backward_op._rename_output(pre_grad_name, - self._scaled_loss_grad.name) + first_backward_op._rename_output( + pre_grad_name, self._scaled_loss_grad.name + ) # FIXME(JZ-LIANG) a trick to insert backward op main_block._sync_with_cpp() elementwise_mul_grad_op_desc = main_block.desc._insert_op( - loss_op_idx + 3) + loss_op_idx + 3 + ) elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad") elementwise_mul_grad_op_desc.set_input( - 'Out@GRAD', [self._scaled_loss_grad.name]) + 'Out@GRAD', [self._scaled_loss_grad.name] + ) elementwise_mul_grad_op_desc.set_input('X', [loss.name]) - elementwise_mul_grad_op_desc.set_input('Y', - [self._loss_scaling.name]) + elementwise_mul_grad_op_desc.set_input( + 'Y', [self._loss_scaling.name] + ) elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name]) elementwise_mul_grad_op_desc.set_output('Y@GRAD', []) elementwise_mul_grad_op_desc._set_attr( - OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward) + OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward + ) elementwise_mul_grad_op_desc._set_attr('axis', -1) elementwise_mul_grad_op = paddle.fluid.framework.Operator( - main_block, elementwise_mul_grad_op_desc) + main_block, elementwise_mul_grad_op_desc + ) main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op) main_block._sync_with_cpp() elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3] assert elementwise_mul_grad_op.type == "elementwise_mul_grad" naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context) + elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context + ) else: self._scaled_loss = loss @@ -728,31 +887,39 @@ class AMPPass(PassBase): main_block = paddle.static.default_main_program().global_block() main_block._sync_with_cpp() - check_variable_and_dtype(self._loss_scaling, "prev_loss_scaling", - ['float32', 'float64'], "update_loss_scaling") + check_variable_and_dtype( + self._loss_scaling, + "prev_loss_scaling", + ['float32', 'float64'], + "update_loss_scaling", + ) check_type(grads, 'x', (tuple, list), 'update_loss_scaling') for e in grads: - check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], - 'update_loss_scaling') + check_variable_and_dtype( + e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling' + ) if e.dtype == core.VarDesc.VarType.FP16: - assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \ - "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." + assert ( + self._loss_scaling.dtype == core.VarDesc.VarType.FP32 + ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." else: - assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." + assert ( + self._loss_scaling.dtype == e.dtype + ), "The dtype of prev_loss_scaling should be equal to the dtype of x." 
inputs = { 'X': grads, 'FoundInfinite': found_inf, 'PrevLossScaling': self._loss_scaling, 'InGoodSteps': self._num_good_steps, - 'InBadSteps': self._num_bad_steps + 'InBadSteps': self._num_bad_steps, } outputs = { 'Out': grads, 'LossScaling': self._loss_scaling, 'OutGoodSteps': self._num_good_steps, - 'OutBadSteps': self._num_bad_steps + 'OutBadSteps': self._num_bad_steps, } attrs = { @@ -761,13 +928,15 @@ class AMPPass(PassBase): 'incr_ratio': self.get_attr("incr_ratio"), 'decr_ratio': self.get_attr("decr_ratio"), 'stop_update': self.get_attr("stop_update"), - 'op_role': OpRole.Optimize + 'op_role': OpRole.Optimize, } - new_op = main_block.append_op(type='update_loss_scaling', - inputs=inputs, - outputs=outputs, - attrs=attrs) + new_op = main_block.append_op( + type='update_loss_scaling', + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr.process_mesh = world_process_group.ranks @@ -777,10 +946,22 @@ class AMPPass(PassBase): for g in grads: g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g) assert g_dist_attr is not None - new_op_dist_attr.set_input_dims_mapping(g.name, - g_dist_attr.dims_mapping) - new_op_dist_attr.set_output_dims_mapping(g.name, - g_dist_attr.dims_mapping) + new_op_dist_attr.set_input_dims_mapping( + g.name, g_dist_attr.dims_mapping + ) + new_op_dist_attr.set_output_dims_mapping( + g.name, g_dist_attr.dims_mapping + ) self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) main_block._sync_with_cpp() + + def get_loss(self): + # the amp might change the effective loss variable for network and + # therefore would affect the subsequent passes that rely on the loss. + # return the effective loss after amp pass. + + if self._loss: + return self._loss + else: + return self.get_attr("loss") diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 34684c6ca41..ac73699acb9 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.process_group import ( get_world_process_group, ) -from paddle.fluid.contrib.mixed_precision.fp16_utils import ( +from paddle.fluid.contrib.mixed_precision.fp16_lists import ( AutoMixedPrecisionLists, ) from paddle.fluid.contrib.mixed_precision.fp16_utils import ( _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, - _dtype_to_str, ) from paddle.distributed.auto_parallel.dist_attribute import ( OperatorDistributedAttribute, @@ -55,6 +54,23 @@ __amp_skip_ops__ = [ 'while', 'cast', ] +__target_dtype__ = None + + +def _dtype_to_str(dtype): + """ + Convert specific variable type to its corresponding string. + Args: + dtype (VarType): Variable type. + """ + if dtype == core.VarDesc.VarType.FP16: + # TODO(Xreki): change the returned str to "bf16" for BF16 data type. + # Currently too many codes use "cast_fp16" as key. 
+ return 'fp16' + elif dtype == core.VarDesc.VarType.BF16: + return 'bf16' + else: + return 'fp32' def set_op_dtype_to_fp16(op): @@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op): op.has_attr('in_dtype') and op.attr('in_dtype') == core.VarDesc.VarType.FP32 ): - op._set_attr('in_dtype', core.VarDesc.VarType.FP16) + op._set_attr('in_dtype', __target_dtype__) if ( op.has_attr('out_dtype') and op.attr('out_dtype') == core.VarDesc.VarType.FP32 ): - op._set_attr('out_dtype', core.VarDesc.VarType.FP16) + op._set_attr('out_dtype', __target_dtype__) if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.FP16) + op._set_attr('dtype', __target_dtype__) + + if __target_dtype__ == core.VarDesc.VarType.BF16: + if op.has_attr('use_mkldnn'): + op._set_attr('use_mkldnn', True) + if op.has_attr('mkldnn_data_type'): + op._set_attr('mkldnn_data_type', 'bfloat16') # adapot for backward op @@ -156,6 +178,7 @@ class FP16State(object): list ) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]} self.is_train = False + self.out_var_op_deps = {} def _is_fp16_op(self, op_id): return self._op_fp16_dict.get(op_id, None) @@ -169,6 +192,14 @@ class FP16State(object): # assume all backward block are behind forward blocks for block in self.program.blocks: for op in block.ops: + for name in op.output_arg_names: + if name not in self.out_var_op_deps: + self.out_var_op_deps[name] = [op.desc.original_id()] + else: + self.out_var_op_deps[name].extend( + [op.desc.original_id()] + ) + self._mark_op(op) # set forward tensor dtype @@ -192,6 +223,18 @@ class FP16State(object): if op.type == "assign" and "array_" in op.input_arg_names[0]: self._op_fp16_dict[op.desc.original_id()] = False return + # If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var. 
+ if op.type == "assign": + out_name = op.output_arg_names[0] + if len(self.out_var_op_deps[out_name]) > 1: + if not self._op_fp16_dict[ + self.out_var_op_deps[out_name][0] + ]: + self._op_fp16_dict[op.desc.original_id()] = False + else: + self._op_fp16_dict[op.desc.original_id()] = True + return + if _need_keep_fp32( op, self.amp_list.unsupported_list, self.use_fp16_guard ): @@ -228,7 +271,7 @@ class FP16State(object): return if var.dtype == core.VarDesc.VarType.FP32: - var.desc.set_dtype(core.VarDesc.VarType.FP16) + var.desc.set_dtype(__target_dtype__) def resolute_tensor_dtype(self, block): @@ -260,7 +303,7 @@ class FP16State(object): out_var = block.vars.get(out_var_name) if out_var is None or out_var.type not in _valid_types: continue - if out_var.dtype == core.VarDesc.VarType.FP16: + if out_var.dtype == __target_dtype__: out_var.desc.set_dtype(core.VarDesc.VarType.FP32) elif is_backward_op(op): if self._is_fp16_op(op.desc.original_id()) == True: @@ -276,7 +319,7 @@ class FP16State(object): out_var = block.vars.get(out_var_name) if out_var is None or out_var.type not in _valid_types: continue - if out_var.dtype == core.VarDesc.VarType.FP16: + if out_var.dtype == __target_dtype__: out_var.desc.set_dtype(core.VarDesc.VarType.FP32) def cast_block(self, block): @@ -295,7 +338,7 @@ class FP16State(object): op, idx, block, - core.VarDesc.VarType.FP16, + __target_dtype__, core.VarDesc.VarType.FP32, self.dist_context, ) @@ -305,7 +348,7 @@ class FP16State(object): idx, block, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, + __target_dtype__, self.dist_context, ) elif is_backward_op(op): @@ -315,7 +358,7 @@ class FP16State(object): op, idx, block, - core.VarDesc.VarType.FP16, + __target_dtype__, core.VarDesc.VarType.FP32, self.dist_context, ) @@ -325,7 +368,7 @@ class FP16State(object): idx, block, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, + __target_dtype__, self.dist_context, ) elif op.type == "sum": @@ -399,6 +442,9 @@ class FP16State(object): dist_context, cast_var, ref_mapping, ref_mesh ) + op_namescope = "/" + if op.has_attr('op_namescope'): + op_namescope = op.attr('op_namescope') cast_op = block._insert_op_without_sync( idx, type="cast", @@ -410,6 +456,9 @@ class FP16State(object): OP_ROLE_KEY: OpRole.Forward, }, ) + cast_op._set_attr( + 'op_namescope', op_namescope + ) # for recompute naive_set_dist_op_attr_for_program_by_mesh_and_mapping( cast_op, ref_mesh, ref_mapping, dist_context ) @@ -455,63 +504,79 @@ class FP16State(object): ) in self.forward_input_cast_ops[forward_op_id]: # rename input - assert src_name in op.input( - slot_name - ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op)) - src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) - assert src_var_dist_attr is not None - op._rename_input(src_name, cast_name) - grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) + # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy + if op.type != "scale" and slot_name in op.input_names: + assert src_name in op.input( + slot_name + ), "var: {} not in op's {}. 
{}".format( + src_name, slot_name, str(op) + ) + src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) + assert src_var_dist_attr is not None + op._rename_input(src_name, cast_name) + grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) + + # NOTE Special for scale op, scale op's grad op is scale, + # so slot name map rule could not apply to grad scale op + # cast_name: mean_0.tmp_0.cast_bf16, src_name: mean_0.tmp_0, dst_dtype: paddle.bfloat16, src_dtype: paddle.float32, slot_name: X. + if op.type == "scale": + grad_slot_name = "X" # create cast grad - grad_slot_name = slot_name + "@GRAD" - assert grad_slot_name in op.output_names - if len(op.output(grad_slot_name)) == 0: - var = block.var(src_name) - assert var.stop_gradient is True - continue - assert len(op.output(grad_slot_name)) == 1 - grad_name = op.output(grad_slot_name)[0] - grad = block.var(grad_name) - grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name) - assert grad_dist_attr is not None, "{}".format(grad_name) - ref_mesh = grad_dist_attr.process_mesh - ref_mapping = grad_dist_attr.dims_mapping - - cast_grad = block.create_var( - name=unique_name.generate_with_ignorable_key( - "".join([cast_name, '@GRAD']) - ), - dtype=dst_dtype, - shape=grad.shape, - type=grad.type, - persistable=grad.persistable, - stop_gradient=grad.stop_gradient, - ) - dist_context.set_tensor_dist_attr_for_program( - cast_grad, grad_dist_attr - ) - op._rename_output(grad_name, cast_grad.name) - grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr) - - # add cast - cast_op = block._insert_op_without_sync( - idx + 1, - type="cast", - inputs={"X": [cast_grad.name]}, - outputs={"Out": [grad.name]}, - attrs={ - "in_dtype": dst_dtype, - "out_dtype": src_dtype, - OP_ROLE_KEY: OpRole.Backward, - }, - ) - grad.desc.set_dtype(src_dtype) + else: + grad_slot_name = slot_name + "@GRAD" - naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - cast_op, ref_mesh, ref_mapping, dist_context - ) - num_cast_ops += 1 + if grad_slot_name in op.output_names: + # some forward input maybe stop_gradient=True, e.g. 
input_mask + if len(op.output(grad_slot_name)) == 0: + continue + assert ( + len(op.output(grad_slot_name)) == 1 + ), "[{}], Current Op: {}".format(grad_slot_name, str(op)) + + grad_name = op.output(grad_slot_name)[0] + grad = block.var(grad_name) + grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name) + assert grad_dist_attr is not None, "{}".format(grad_name) + ref_mesh = grad_dist_attr.process_mesh + ref_mapping = grad_dist_attr.dims_mapping + + cast_grad = block.create_var( + name=unique_name.generate_with_ignorable_key( + "".join([cast_name, '@GRAD']) + ), + dtype=dst_dtype, + shape=grad.shape, + type=grad.type, + persistable=grad.persistable, + stop_gradient=grad.stop_gradient, + ) + dist_context.set_tensor_dist_attr_for_program( + cast_grad, grad_dist_attr + ) + op._rename_output(grad_name, cast_grad.name) + grad_op_attr.set_output_dist_attr( + cast_grad.name, grad_dist_attr + ) + + # add cast + cast_op = block._insert_op_without_sync( + idx + 1, + type="cast", + inputs={"X": [cast_grad.name]}, + outputs={"Out": [grad.name]}, + attrs={ + "in_dtype": dst_dtype, + "out_dtype": src_dtype, + OP_ROLE_KEY: OpRole.Backward, + }, + ) + grad.desc.set_dtype(src_dtype) + + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, ref_mapping, dist_context + ) + num_cast_ops += 1 return num_cast_ops @@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): def _split_grads(params_grads): grads = [g for _, g in params_grads] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] - fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] + fp16_grads = [g for g in grads if g.dtype == __target_dtype__] assert len(fp32_grads) + len(fp16_grads) == len( grads ), "Data types of all grads must be either fp16 or fp32." @@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): # TODO to support CUDAPinned/NPU/XPU Places if direction == "D2H": dst_place_type = 0 - elif direction == "D2H": - dst_place_type = 1 else: raise NotImplementedError( - "direction [{}] is not supported yet.".format(direction) + f"direction [{direction}] is not supported yet." ) attrs = {'dst_place_type': dst_place_type} new_op = block._insert_op_without_sync( index=idx, - type='memcpy', + type='memcpy_d2h', inputs={'X': [src_var]}, outputs={'Out': [output_var]}, attrs=attrs, @@ -678,17 +741,17 @@ def cast_startup_program(): for op in startup_program.global_block().ops: if is_initialization_op(op): output_name = op.output_arg_names[0] - if ( - param_to_dtype.get(output_name, None) - == core.VarDesc.VarType.FP16 - ): + if param_to_dtype.get(output_name, None) == __target_dtype__: assert op.has_attr( 'dtype' ), "initialization op is supported to has dtype attribute but got {}.".format( str(op) ) + out_var = startup_program.global_block().var(output_name) + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(__target_dtype__) if op.attr('dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.FP16) + op._set_attr('dtype', __target_dtype__) @register_pass("auto_parallel_fp16") @@ -701,14 +764,44 @@ class FP16Pass(AMPPass): # in distributed scenario, all ranks should have the same modification. 
def _apply_single_impl(self, main_program, startup_program, context): self.dist_context = self.get_attr("dist_context") + self.target_dtype = self.get_attr("dtype") params_grads = self.get_attr("params_grads") + self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None) + if self.use_optimizer_fp16 is None: + self.use_optimizer_fp16 = self.get_attr("level", None) == "o3" + + # swith enviroment for fp16 / bf16. + if self.target_dtype == "float16": + __target_dtype = core.VarDesc.VarType.FP16 + + elif self.target_dtype == "bfloat16": + __target_dtype = core.VarDesc.VarType.BF16 + else: + raise NotImplementedError( + "target dtype [{}] is for amp o2 not supported yet.".format( + self.target_dtype + ) + ) + global __target_dtype__ + __target_dtype__ = __target_dtype amp_list = AutoMixedPrecisionLists( set(self.get_attr("custom_white_list")), set(self.get_attr("custom_black_list")), - None, + dtype=self.target_dtype, ) - + amp_list.unsupported_list -= { + "conditional_block_grad", + "conditional_block", + "conditional_block_infer", + "select_input", + "while", + "while_grad", + "cast", + "tensor_array_to_tensor", + "lod_array_length", + "write_to_array", + } # NOTE don't not change input data dtype, since it is controled by dataloader # and which is out of control of FP16 Pass input_data_var_names = [var.name for var in self.get_attr("input_data")] @@ -726,93 +819,96 @@ class FP16Pass(AMPPass): cast_startup_program() if is_train: - with paddle.static.program_guard(main_program, startup_program): - # TODO (JZ-LIANG)support cast forward program only when inference - self._init_amp_var() - self._scale_loss() - - grads, fp32_grads, fp16_grads = _split_grads(params_grads) - - if ( - self.get_attr("use_dynamic_loss_scaling") - or self.get_attr("init_loss_scaling") != 1.0 - ): - found_infs = [] - if fp32_grads: + if self.target_dtype == "fp16": + with paddle.static.program_guard(main_program, startup_program): + # TODO (JZ-LIANG)support cast forward program only when inference + self._init_amp_var() + self._scale_loss() + + grads, fp32_grads, fp16_grads = _split_grads(params_grads) + + if ( + self.get_attr("use_dynamic_loss_scaling") + or self.get_attr("init_loss_scaling") != 1.0 + ): + found_infs = [] + if fp32_grads: + with main_program._optimized_guard([]): + _, found_inf_fp32 = _check_and_update_gradient( + fp32_grads, + self._loss_scaling, + "@fp32", + self.dist_context, + ) + found_infs.append(found_inf_fp32) + if fp16_grads: + with main_program._optimized_guard([]): + _, found_inf_fp16 = _check_and_update_gradient( + fp16_grads, + self._loss_scaling, + "@fp16", + self.dist_context, + ) + found_infs.append(found_inf_fp16) with main_program._optimized_guard([]): - _, found_inf_fp32 = _check_and_update_gradient( - fp32_grads, - self._loss_scaling, - "@fp32", + block = main_program.global_block() + + all_infs = paddle.fluid.layers.concat(found_infs) + set_var_dist_attr( self.dist_context, + all_infs, + [-1], + world_process_group.ranks, ) - found_infs.append(found_inf_fp32) - if fp16_grads: - with main_program._optimized_guard([]): - _, found_inf_fp16 = _check_and_update_gradient( - fp16_grads, - self._loss_scaling, - "@fp16", + new_op = block.ops[-1] + assert new_op.type == "concat" + _set_op_dist_attr_with_ranks( + new_op, + world_process_group.ranks, + block, self.dist_context, ) - found_infs.append(found_inf_fp16) - with main_program._optimized_guard([]): - block = main_program.global_block() - - all_infs = paddle.fluid.layers.concat(found_infs) - set_var_dist_attr( - 
self.dist_context, - all_infs, - [-1], - world_process_group.ranks, - ) - new_op = block.ops[-1] - assert new_op.type == "concat" - _set_op_dist_attr_with_ranks( - new_op, - world_process_group.ranks, - block, - self.dist_context, - ) - found_inf = paddle.fluid.layers.reduce_any(all_infs) - set_var_dist_attr( - self.dist_context, - found_inf, - [-1], - world_process_group.ranks, - ) - new_op = block.ops[-1] - assert new_op.type == "reduce_any" - _set_op_dist_attr_with_ranks( - new_op, - world_process_group.ranks, - block, - self.dist_context, - ) + found_inf = paddle.fluid.layers.reduce_any(all_infs) + set_var_dist_attr( + self.dist_context, + found_inf, + [-1], + world_process_group.ranks, + ) + new_op = block.ops[-1] + assert new_op.type == "reduce_any" + _set_op_dist_attr_with_ranks( + new_op, + world_process_group.ranks, + block, + self.dist_context, + ) - if self.get_attr("use_dynamic_loss_scaling"): - with main_program._optimized_guard([]): - if fp32_grads: - self._update_loss_scaling(fp32_grads, found_inf) - if fp16_grads: - self._update_loss_scaling(fp16_grads, found_inf) + if self.get_attr("use_dynamic_loss_scaling"): + with main_program._optimized_guard([]): + if fp32_grads: + self._update_loss_scaling(fp32_grads, found_inf) + if fp16_grads: + self._update_loss_scaling(fp16_grads, found_inf) # modify optimizer base_opt = self.get_attr("base_opt") base_opt._multi_precision = True - if self.get_attr("use_optimizer_fp16"): + if self.use_optimizer_fp16: base_opt._multi_precision = False - if isinstance( - base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW) - ): - with main_program._optimized_guard([]): - # found_inf = paddle.tensor.creation._memcpy( - # found_inf, paddle.CPUPlace()) - insert_idx = _get_memcopy_idx(block, found_inf) - found_inf = _insert_memcopy( - block, insert_idx, found_inf, self.dist_context - ) - base_opt._set_auxiliary_var('found_inf', found_inf.name) - elif hasattr(base_opt, "_set_auxiliary_var"): - base_opt._set_auxiliary_var('found_inf', found_inf.name) + if self.target_dtype == "fp16": + if isinstance( + base_opt, + (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW), + ): + with main_program._optimized_guard([]): + # found_inf = paddle.tensor.creation._memcpy( + # found_inf, paddle.CPUPlace()) + insert_idx = _get_memcopy_idx(block, found_inf) + found_inf = _insert_memcopy( + block, insert_idx, found_inf, self.dist_context + ) + base_opt._set_auxiliary_var('found_inf', found_inf.name) + elif hasattr(base_opt, "_set_auxiliary_var"): + base_opt._set_auxiliary_var('found_inf', found_inf.name) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 9079a0b7535..ee3f855b1c4 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS}) set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass ENVS ${dist_ENVS}) + set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 50) py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS ${dist_ENVS}) set_tests_properties(test_iterable_dataset diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py 
b/python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py new file mode 100644 index 00000000000..4291dc296db --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py @@ -0,0 +1,142 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import re +import unittest + +import numpy as np +from get_gpt_model import FakeDataset, generate_model + +import paddle +from paddle.distributed.fleet import auto +from paddle.fluid.framework import core + +paddle.enable_static() + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +def apply_pass(use_amp=False, amp_dtype="bfloat16"): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + + if use_amp: + amp = strategy.amp + amp.enable = True + amp.dtype = amp_dtype + amp.level = "o2" + amp.custom_black_list = [ + 'c_softmax_with_cross_entropy', + 'elementwise_div', + 'reduce_sum', + ] + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestShardingStage2WithNewEXE(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_amp=False, amp_dtype="bfloat16"): + reset_prog() + + strategy = apply_pass(use_amp, amp_dtype) + # clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + clip = None + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_bf16(self, program): + num_bf16 = 0 + num_fp16 = 0 + num_fp32 = 0 + + for p in program.all_parameters(): + if p.dtype == core.VarDesc.VarType.FP32: + num_fp32 += 1 + if p.dtype == core.VarDesc.VarType.FP16: + num_fp16 += 1 + if p.dtype == core.VarDesc.VarType.BF16: + num_bf16 += 1 + + self.assertEqual(num_bf16, 25) + self.assertEqual(num_fp16, 0) + self.assertEqual(num_fp32, 11) + + def test_param_grad_fuse_overlap(self): + # std + mp_engine = self.get_engine(use_amp=False) + mp_history = mp_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + loss0 = mp_history.history['loss'][0] + + # bf16 + mp_bf16_engine = self.get_engine(use_amp=True) + if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000: 
+ return + + mp_bf16_history = mp_bf16_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + loss1 = mp_bf16_history.history['loss'][0] + np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2) + + self.check_bf16(mp_bf16_engine.main_program) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py index 6d96cd13773..c5096d88bc2 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py @@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None): ] amp.init_loss_scaling = 32768 amp.use_fp16_guard = False - amp.use_pure_fp16 = level in ["o2", "o3"] + amp.level = level amp.use_optimizer_fp16 = level == "o3" print("amp level: ", level) return strategy diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py new file mode 100644 index 00000000000..9c3d9797bc6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +import tempfile +import unittest + + +class TestAMPO2(unittest.TestCase): + def test_bf16(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "amp_o2_pass.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = ( + [sys.executable, "-u"] + + coverage_args + + [ + "-m", + "paddle.distributed.launch", + "--devices", + "0,1", + "--log_dir", + tmp_dir.name, + launch_model_path, + ] + ) + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 4d17ea10dcb..ec570a99364 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -13,13 +13,13 @@ # limitations under the License. 
import os + # import yaml import unittest from paddle.distributed.fleet import auto class TestStrategy(unittest.TestCase): - def test_default_config(self): strategy = auto.Strategy() @@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase): amp = strategy.amp self.assertEqual(amp.enable, False) + self.assertAlmostEqual(amp.dtype, "float16") + self.assertAlmostEqual(amp.level, "o1") self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) self.assertEqual(amp.incr_every_n_steps, 1000) self.assertEqual(amp.decr_every_n_nan_or_inf, 2) @@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase): self.assertEqual(amp.custom_black_list, []) self.assertEqual(amp.custom_white_list, []) self.assertEqual(amp.custom_black_varnames, []) - self.assertEqual(amp.use_pure_fp16, False) - self.assertEqual(amp.use_fp16_guard, True) + self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_optimizer_fp16, False) sharding = strategy.sharding @@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase): amp.custom_white_list = ["x"] amp.custom_black_list = ["y"] amp.custom_black_varnames = ["z"] - amp.use_pure_fp16 = True amp.use_fp16_guard = False amp.use_optimizer_fp16 = True self.assertEqual(amp.enable, True) @@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase): self.assertEqual(amp.custom_white_list, ["x"]) self.assertEqual(amp.custom_black_list, ["y"]) self.assertEqual(amp.custom_black_varnames, ["z"]) - self.assertEqual(amp.use_pure_fp16, True) self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_optimizer_fp16, True) -- GitLab
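
Usage sketch (not taken from this patch): the new amp.dtype and amp.level strategy fields added above are exercised end-to-end by amp_o2_pass.py, which builds a GPT model from test helpers that are not shown here. The snippet below illustrates the same BF16 O2 configuration on a self-contained toy model; SimpleMLP, RandomDataset, and all hyper-parameters are placeholders chosen only for illustration, and BF16 assumes an Ampere-or-newer GPU with CUDA 11+, the same guard the unit test applies. Like the test, it would normally be launched with python -m paddle.distributed.launch --devices 0,1 <script>.py.

import numpy as np

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.framework import core
from paddle.io import Dataset

paddle.enable_static()


class RandomDataset(Dataset):
    # placeholder dataset: random (feature, label) pairs
    def __init__(self, num_samples=8):
        self.num_samples = num_samples

    def __getitem__(self, index):
        x = np.random.rand(16).astype("float32")
        y = np.random.rand(1).astype("float32")
        return x, y

    def __len__(self):
        return self.num_samples


class SimpleMLP(paddle.nn.Layer):
    # placeholder network standing in for the GPT model used by amp_o2_pass.py
    def __init__(self):
        super().__init__()
        self.linear1 = paddle.nn.Linear(16, 64)
        self.linear2 = paddle.nn.Linear(64, 1)

    def forward(self, x):
        return self.linear2(paddle.nn.functional.relu(self.linear1(x)))


# enable BF16 O2 through the strategy fields introduced by this patch
strategy = auto.Strategy()
strategy.auto_mode = "semi"
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"  # new field, default "float16"
amp.level = "o2"        # new field, default "o1"
amp.custom_black_list = ["reduce_sum"]

model = SimpleMLP()
loss = paddle.nn.MSELoss()
opt = paddle.optimizer.AdamW(learning_rate=1e-5)
engine = auto.Engine(model, loss, opt, strategy=strategy)
engine.fit(RandomDataset(), 1, epochs=1, batch_size=2, log_freq=1)

# after the O2 pass most trainable parameters should be BF16,
# which is what check_bf16() in amp_o2_pass.py asserts for the GPT model
num_bf16 = sum(
    1
    for p in engine.main_program.all_parameters()
    if p.dtype == core.VarDesc.VarType.BF16
)
print("bf16 parameters:", num_bf16)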