diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index 4112a8606a2774e6fd14f71fa2f9767a8caee827..1a0cb322cf499458323da2ab570d4a0f09e31a78 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -179,6 +179,136 @@ class VocabParallelEmbedding(paddle.nn.Layer):
         return output
 
 
+class InnerOverlapLinear(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias,
+        fuse_matmul_bias,
+        mp_async_allreduce,
+        mp_skip_c_identity,
+        mp_fused_linear_param_grad_add,
+        model_parallel_group,
+    ):
+        ctx.save_for_backward(x, weight, bias)
+        ctx.model_parallel_group = model_parallel_group
+        ctx.mp_fused_linear_param_grad_add = mp_fused_linear_param_grad_add
+        if mp_skip_c_identity is False:
+            x = paddle._legacy_C_ops.c_identity(
+                x,
+                'use_calc_stream',
+                True,
+                'ring_id',
+                model_parallel_group.id,
+                'use_model_parallel',
+                True,
+            )
+        if not fuse_matmul_bias:
+            return paddle._C_ops.linear(x, weight, bias)
+        else:
+            return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
+
+    @staticmethod
+    def backward(ctx, dy):
+        x, weight, bias = ctx.saved_tensor()
+        dx = paddle.matmul(dy, weight, transpose_y=True)
+        op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
+        task = ctx.model_parallel_group.process_group.all_reduce(
+            dx, op_type, sync_op=False
+        )
+        # TODO(GhostScreaming): remove it in future.
+        tmp = paddle.ones([512])
+
+        if ctx.mp_fused_linear_param_grad_add:
+            if not is_fused_linear_param_grad_add_supported():
+                raise NotImplementedError(
+                    "You set mp_fused_linear_param_grad_add=True, "
+                    "however, the paddle you are using not support this operation. "
+                    "Please unset fused_linear_param_grad_add or use paddle compiled "
+                    "with cuda 11.6 or higher."
+                )
+
+            if bias is None:
+                if hasattr(weight, "main_grad"):
+                    (
+                        weight.main_grad,
+                        _,
+                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                        x, dy, weight.main_grad, None, True, False
+                    )
+                    task.wait()
+                    return dx, None
+                else:
+                    if weight.grad is not None:
+                        (
+                            weight.grad,
+                            _,
+                        ) = paddle._C_ops.fused_linear_param_grad_add(
+                            x, dy, weight.grad, None, False, False
+                        )
+                        task.wait()
+                        return dx, None
+                    else:
+                        (
+                            dw,
+                            _,
+                        ) = paddle._C_ops.fused_linear_param_grad_add(
+                            x, dy, None, None, False, False
+                        )
+                        task.wait()
+                        return dx, dw
+
+            if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
+                (
+                    weight.main_grad,
+                    bias.main_grad,
+                ) = paddle._C_ops.fused_linear_param_grad_add(
+                    x,
+                    dy,
+                    weight.main_grad,
+                    bias.main_grad,
+                    True,
+                    True,
+                )
+                task.wait()
+                return dx, None, None
+            else:
+                if weight.grad is not None:
+                    assert bias.grad is not None
+                    (
+                        weight.grad,
+                        bias.grad,
+                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                        x, dy, weight.grad, bias.grad, False, True
+                    )
+                    task.wait()
+                    return dx, None, None
+                else:
+                    (
+                        dw,
+                        dbias,
+                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                        x, dy, None, None, False, True
+                    )
+                    task.wait()
+                    return dx, dw, dbias
+        else:
+            dw = paddle.matmul(
+                x.reshape([-1, x.shape[-1]]),
+                dy.reshape([-1, dy.shape[-1]]),
+                transpose_x=True,
+            )
+            if bias is None:
+                task.wait()
+                return dx, dw
+            else:
+                dbias = paddle.sum(dy, axis=0)
+                task.wait()
+                return dx, dw, dbias
+
+
 class ColumnParallelLinear(paddle.nn.Layer):
     """Linear layer with mp parallelized(column).
     this class is used for splitting Linear Layer in mp group, column split the weight of the Linear layer.
@@ -336,133 +466,16 @@ class ColumnParallelLinear(paddle.nn.Layer):
 
         # use inner api to process identity
         def _overlap_linear():
-            fuse_matmul_bias = self.fuse_matmul_bias
-            mp_async_allreduce = self.mp_async_allreduce
-            mp_skip_c_identity = self.mp_skip_c_identity
-            mp_fused_linear_param_grad_add = self.mp_fused_linear_param_grad_add
-
-            class InnerOverlapLinear(paddle.autograd.PyLayer):
-                @staticmethod
-                def forward(ctx, x, weight, bias):
-                    ctx.save_for_backward(x, weight, bias)
-                    if mp_skip_c_identity is False:
-                        x = paddle._legacy_C_ops.c_identity(
-                            x,
-                            'use_calc_stream',
-                            True,
-                            'ring_id',
-                            self.model_parallel_group.id,
-                            'use_model_parallel',
-                            True,
-                        )
-                    if not fuse_matmul_bias:
-                        return paddle._C_ops.linear(x, weight, bias)
-                    else:
-                        return paddle._legacy_C_ops.fused_gemm_epilogue(
-                            x, weight, bias
-                        )
-
-                @staticmethod
-                def backward(ctx, dy):
-                    x, weight, bias = ctx.saved_tensor()
-                    dx = paddle.matmul(dy, weight, transpose_y=True)
-                    op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
-                    task = self.model_parallel_group.process_group.all_reduce(
-                        dx, op_type, sync_op=False
-                    )
-                    # TODO(GhostScreaming): remove it in future.
-                    tmp = paddle.ones([512])
-
-                    if mp_fused_linear_param_grad_add:
-                        if not is_fused_linear_param_grad_add_supported():
-                            raise NotImplementedError(
-                                "You set mp_fused_linear_param_grad_add=True, "
-                                "however, the paddle you are using not support this operation. "
-                                "Please unset fused_linear_param_grad_add or use paddle compiled "
-                                "with cuda 11.6 or higher."
-                            )
-
-                        if bias is None:
-                            if hasattr(weight, "main_grad"):
-                                (
-                                    weight.main_grad,
-                                    _,
-                                ) = paddle._C_ops.fused_linear_param_grad_add(
-                                    x, dy, weight.main_grad, None, True, False
-                                )
-                                task.wait()
-                                return dx, None
-                            else:
-                                if weight.grad is not None:
-                                    (
-                                        weight.grad,
-                                        _,
-                                    ) = paddle._C_ops.fused_linear_param_grad_add(
-                                        x, dy, weight.grad, None, False, False
-                                    )
-                                    task.wait()
-                                    return dx, None
-                                else:
-                                    (
-                                        dw,
-                                        _,
-                                    ) = paddle._C_ops.fused_linear_param_grad_add(
-                                        x, dy, None, None, False, False
-                                    )
-                                    task.wait()
-                                    return dx, dw
-
-                        if hasattr(weight, "main_grad") and hasattr(
-                            bias, "main_grad"
-                        ):
-                            (
-                                weight.main_grad,
-                                bias.main_grad,
-                            ) = paddle._C_ops.fused_linear_param_grad_add(
-                                input,
-                                dy,
-                                weight.main_grad,
-                                bias.main_grad,
-                                True,
-                                True,
-                            )
-                            task.wait()
-                            return dx, None, None
-                        else:
-                            if weight.grad is not None:
-                                assert bias.grad is not None
-                                (
-                                    weight.grad,
-                                    bias.grad,
-                                ) = paddle._C_ops.fused_linear_param_grad_add(
-                                    x, dy, weight.grad, bias.grad, False, True
-                                )
-                                task.wait()
-                                return dx, None, None
-                            else:
-                                (
-                                    dw,
-                                    dbias,
-                                ) = paddle._C_ops.fused_linear_param_grad_add(
-                                    x, dy, None, None, False, True
-                                )
-                                task.wait()
-                                return dx, dw, dbias
-                    else:
-                        dw = paddle.matmul(
-                            x.reshape([-1, x.shape[-1]]),
-                            dy.reshape([-1, dy.shape[-1]]),
-                            transpose_x=True,
-                        )
-                        if bias is None:
-                            task.wait()
-                            return dx, dw
-                        else:
-                            dbias = paddle.sum(dy, axis=0)
-                            task.wait()
-                            return dx, dw, dbias
-
-            return InnerOverlapLinear.apply(x, self.weight, self.bias)
+            return InnerOverlapLinear.apply(
+                x,
+                self.weight,
+                self.bias,
+                self.fuse_matmul_bias,
+                self.mp_async_allreduce,
+                self.mp_skip_c_identity,
+                self.mp_fused_linear_param_grad_add,
+                self.model_parallel_group,
+            )
 
         if self.mp_async_allreduce:
             output_parallel = _overlap_linear()
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index 76692fae634ec140fae12f333723ec2f9ed0c65e..82a072fe056404c227f374c08a79f5592369e49a
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -14,6 +14,7 @@
 
 import paddle
 from paddle import _legacy_C_ops
+from paddle.autograd import PyLayer
 from paddle.distributed import collective
 from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype
 from paddle.framework import LayerHelper, _create_tensor, in_dynamic_mode
@@ -23,6 +24,30 @@ from paddle.nn.utils import dygraph_utils
 
 from ....communication.reduce import ReduceOp, _get_reduce_op
+
+class c_identity_eager(PyLayer):
+    @staticmethod
+    def forward(ctx, tensor, group, skip_c_identity_dynamic):
+        ctx.group = group
+        if skip_c_identity_dynamic:
+            return tensor
+        else:
+            return _legacy_C_ops.c_identity(
+                tensor,
+                'use_calc_stream',
+                True,
+                'ring_id',
+                group.id,
+                'use_model_parallel',
+                True,
+            )
+
+    @staticmethod
+    def backward(ctx, dy):
+        op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
+        ctx.group.process_group.all_reduce_on_calc_stream(dy, op_type)
+        return dy
+
 
 def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
     """
     Return a copy of the tensor, mainly used with model parallel.
@@ -40,31 +65,7 @@ def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
     ring_id = 0 if group is None else group.id
 
     if in_dynamic_mode():
-        from paddle.autograd import PyLayer
-
-        class c_identity_eager(PyLayer):
-            @staticmethod
-            def forward(ctx, tensor):
-                if skip_c_identity_dynamic:
-                    return tensor
-                else:
-                    return _legacy_C_ops.c_identity(
-                        tensor,
-                        'use_calc_stream',
-                        True,
-                        'ring_id',
-                        group.id,
-                        'use_model_parallel',
-                        True,
-                    )
-
-            @staticmethod
-            def backward(ctx, dy):
-                op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
-                group.process_group.all_reduce_on_calc_stream(dy, op_type)
-                return dy
-
-        return c_identity_eager.apply(tensor)
+        return c_identity_eager.apply(tensor, group, skip_c_identity_dynamic)
     else:
         op_type = 'c_identity'
         helper = LayerHelper(op_type, **locals())
@@ -218,6 +219,49 @@ def _c_split(tensor, group=None):
     return out
 
 
+
+class mp_allreduce_eager(PyLayer):
+    @staticmethod
+    def forward(
+        ctx,
+        tensor,
+        group,
+        use_calc_stream,
+        use_model_parallel,
+        op,
+        skip_c_identity_dynamic,
+    ):
+        ctx.ring_id = group.id
+        ctx.skip_c_identity_dynamic = skip_c_identity_dynamic
+
+        if use_calc_stream:
+            op_type = _get_reduce_op(op, "_mp_allreduce")
+            group.process_group.all_reduce_on_calc_stream(tensor, op_type)
+            return tensor
+        else:
+            return _legacy_C_ops.c_allreduce_sum_(
+                tensor,
+                'use_calc_stream',
+                use_calc_stream,
+                'ring_id',
+                group.id,
+            )
+
+    @staticmethod
+    def backward(ctx, dy):
+        if ctx.skip_c_identity_dynamic:
+            return dy
+        else:
+            return _legacy_C_ops.c_identity(
+                dy,
+                'use_calc_stream',
+                True,
+                'ring_id',
+                ctx.ring_id,
+                'use_model_parallel',
+                True,
+            )
+
 def _mp_allreduce(
     tensor,
     op=ReduceOp.SUM,
@@ -233,48 +277,13 @@ def _mp_allreduce(
     if in_dynamic_mode():
         group = collective._get_default_group() if group is None else group
         assert op == ReduceOp.SUM, f"Unknown parameter: {op}."
-
-        from paddle.autograd import PyLayer
-
-        class mp_allreduce_eager(PyLayer):
-            @staticmethod
-            def forward(
-                ctx, tensor, group, use_calc_stream, use_model_parallel
-            ):
-                ctx.ring_id = group.id
-
-                if use_calc_stream:
-                    op_type = _get_reduce_op(op, "_mp_allreduce")
-                    group.process_group.all_reduce_on_calc_stream(
-                        tensor, op_type
-                    )
-                    return tensor
-                else:
-                    return _legacy_C_ops.c_allreduce_sum_(
-                        tensor,
-                        'use_calc_stream',
-                        use_calc_stream,
-                        'ring_id',
-                        ring_id,
-                    )
-
-            @staticmethod
-            def backward(ctx, dy):
-                if skip_c_identity_dynamic:
-                    return dy
-                else:
-                    return _legacy_C_ops.c_identity(
-                        dy,
-                        'use_calc_stream',
-                        True,
-                        'ring_id',
-                        ctx.ring_id,
-                        'use_model_parallel',
-                        True,
-                    )
-
         return mp_allreduce_eager.apply(
-            tensor, group, use_calc_stream, use_model_parallel
+            tensor,
+            group,
+            use_calc_stream,
+            use_model_parallel,
+            op,
+            skip_c_identity_dynamic,
         )
     else:
         ring_id = 0 if group is None else group.id
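For reference, the pattern both files move to is a module-level PyLayer whose former closure state travels through apply() as explicit forward arguments and is stashed on ctx for backward. The sketch below is illustrative only, assuming a dygraph session; the ScaleByFactor class and the scaling operation are made-up names, not part of this change. It shows how a non-tensor argument reaches backward once the class no longer closes over self or enclosing locals.

import paddle


class ScaleByFactor(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, factor):
        # `factor` is plain Python data, not a tensor: keep it on ctx
        # instead of capturing it from `self` or an enclosing function.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, dy):
        # A gradient is returned only for the tensor input `x`.
        return dy * ctx.factor


x = paddle.randn([4, 4])
x.stop_gradient = False
y = ScaleByFactor.apply(x, 2.0)
y.sum().backward()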