diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index 55bec32bb1a5cfdfdafe27f5472194520488a089..c9a6b94576d6e3b67ff22015ae7b1d1742163c92 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -17,8 +17,10 @@ from paddle.autograd import PyLayer
 from paddle.fluid import core
 from paddle.nn import functional as F
 
+from ....communication.reduce import ReduceOp, _get_reduce_op
 from ...base import topology as tp
 from . import mp_ops
+from .mp_ops import _get_mp_env_flag
 from .random import get_rng_state_tracker
 
 __all__ = []
@@ -32,6 +34,13 @@ def is_fused_matmul_bias_supported():
         return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')
 
 
+def is_fused_linear_param_grad_add_supported():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+        return hasattr(paddle._C_ops, 'fused_linear_param_grad_add')
+    else:
+        return False
+
+
 class VocabParallelEmbedding(paddle.nn.Layer):
     """Embedding mp parallelized in the vocabulary dimension.
     this class is used for splitting embedding in mp group.
@@ -295,7 +304,8 @@ class ColumnParallelLinear(paddle.nn.Layer):
 
         self.linear = F.linear
 
-        if fuse_matmul_bias:
+        self.fuse_matmul_bias = fuse_matmul_bias
+        if self.fuse_matmul_bias:
             if not is_fused_matmul_bias_supported():
                 raise NotImplementedError(
                     "You set fuse_matmul_bias=True in ColumnParallelLinear, "
@@ -309,16 +319,149 @@ class ColumnParallelLinear(paddle.nn.Layer):
 
     def forward(self, x):
         # use inner api to process identity
-        if self.is_mp:
-            input_parallel = mp_ops._c_identity(
-                x, group=self.model_parallel_group
-            )
+
+        def _overlap_linear():
+            fuse_matmul_bias = self.fuse_matmul_bias
+
+            class InnerOverlapLinear(paddle.autograd.PyLayer):
+                @staticmethod
+                def forward(ctx, x, weight, bias):
+                    ctx.save_for_backward(x, weight, bias)
+                    if not (
+                        _get_mp_env_flag("Flags_mp_aysnc_allreduce")
+                        and _get_mp_env_flag("Flags_skip_mp_c_identity")
+                    ):
+                        x = paddle._legacy_C_ops.c_identity(
+                            x,
+                            'use_calc_stream',
+                            True,
+                            'ring_id',
+                            self.model_parallel_group.id,
+                            'use_model_parallel',
+                            True,
+                        )
+                    if not fuse_matmul_bias:
+                        return paddle._C_ops.linear(x, weight, bias)
+                    else:
+                        return paddle._legacy_C_ops.fused_gemm_epilogue(
+                            x, weight, bias
+                        )
+
+                @staticmethod
+                def backward(ctx, dy):
+                    x, weight, bias = ctx.saved_tensor()
+                    dx = paddle.matmul(dy, weight, transpose_y=True)
+                    op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
+                    # Launch all_reduce(dx) asynchronously so it overlaps with the
+                    # matmul(dw) below; every return path calls task.wait() first.
+                    task = self.model_parallel_group.process_group.all_reduce(
+                        dx, op_type, sync_op=False
+                    )
+                    # TODO(GhostScreaming): remove it in the future.
+                    tmp = paddle.ones([512])
+
+                    if _get_mp_env_flag("Flags_fused_linear_param_grad_add"):
+                        if not is_fused_linear_param_grad_add_supported():
+                            raise NotImplementedError(
+                                "You set the environment variable Flags_fused_linear_param_grad_add=True, "
+                                "however, the paddle you are using does not support this operation. "
+                                "Please unset Flags_fused_linear_param_grad_add or use paddle compiled "
+                                "with CUDA 11.6 or higher."
+                            )
+
+                        if bias is None:
+                            if hasattr(weight, "main_grad"):
+                                (
+                                    weight.main_grad,
+                                    _,
+                                ) = paddle._C_ops.fused_linear_param_grad_add(
+                                    x, dy, weight.main_grad, None, True, False
+                                )
+                                task.wait()
+                                return dx, None
+                            else:
+                                if weight.grad is not None:
+                                    (
+                                        weight.grad,
+                                        _,
+                                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                                        x, dy, weight.grad, None, False, False
+                                    )
+                                    task.wait()
+                                    return dx, None
+                                else:
+                                    (
+                                        dw,
+                                        _,
+                                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                                        x, dy, None, None, False, False
+                                    )
+                                    task.wait()
+                                    return dx, dw
+
+                        if hasattr(weight, "main_grad") and hasattr(
+                            bias, "main_grad"
+                        ):
+                            (
+                                weight.main_grad,
+                                bias.main_grad,
+                            ) = paddle._C_ops.fused_linear_param_grad_add(
+                                x,
+                                dy,
+                                weight.main_grad,
+                                bias.main_grad,
+                                True,
+                                True,
+                            )
+                            task.wait()
+                            return dx, None, None
+                        else:
+                            if weight.grad is not None:
+                                assert bias.grad is not None
+                                (
+                                    weight.grad,
+                                    bias.grad,
+                                ) = paddle._C_ops.fused_linear_param_grad_add(
+                                    x, dy, weight.grad, bias.grad, False, True
+                                )
+                                task.wait()
+                                return dx, None, None
+                            else:
+                                (
+                                    dw,
+                                    dbias,
+                                ) = paddle._C_ops.fused_linear_param_grad_add(
+                                    x, dy, None, None, False, True
+                                )
+                                task.wait()
+                                return dx, dw, dbias
+                    else:
+                        dw = paddle.matmul(
+                            x.reshape([-1, x.shape[-1]]),
+                            dy.reshape([-1, dy.shape[-1]]),
+                            transpose_x=True,
+                        )
+                        if bias is None:
+                            task.wait()
+                            return dx, dw
+                        else:
+                            dbias = paddle.sum(dy, axis=0)
+                            task.wait()
+                            return dx, dw, dbias
+
+            return InnerOverlapLinear.apply(x, self.weight, self.bias)
+
+        if _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
+            output_parallel = _overlap_linear()
         else:
-            input_parallel = x
+            if self.is_mp:
+                input_parallel = mp_ops._c_identity(
+                    x, group=self.model_parallel_group
+                )
+            else:
+                input_parallel = x
 
-        output_parallel = self.linear(
-            input_parallel, self.weight, self.bias, name=self._name
-        )
+            output_parallel = self.linear(
+                input_parallel, self.weight, self.bias, name=self._name
+            )
 
         if self.gather_output and self.is_mp:
             output = mp_ops._c_concat(
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index 322281d3c9f65d80139652a851277c6fcd8eb4b0..a7ada15627da40cdec40671dce77735626adc486 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import paddle
 from paddle import _legacy_C_ops
 from paddle.distributed import collective
@@ -22,6 +24,38 @@ from paddle.nn.utils import dygraph_utils
 
 from ....communication.reduce import ReduceOp, _get_reduce_op
 
+_first_get_mp_env_flag = True
+
+
+def _get_mp_env_flag(flag):
+    global _first_get_mp_env_flag
+    if _first_get_mp_env_flag:
+        print(
+            "Flags_mp_aysnc_allreduce is {}, which is used to overlap all_reduce(dx) with matmul(dw) in ColumnParallelLinear.".format(
+                str(os.getenv("Flags_mp_aysnc_allreduce")).lower()
+            )
+        )
+        print(
+            "Flags_fused_linear_param_grad_add is {}, which is used to enable fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
+                str(os.getenv("Flags_fused_linear_param_grad_add")).lower()
+            )
+        )
+        print(
+            "Flags_skip_mp_c_identity is {}, which is used to skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
+                str(os.getenv("Flags_skip_mp_c_identity")).lower()
+            )
+        )
+        _first_get_mp_env_flag = False
+    # Model parallel environment flags:
+    # Flags_mp_aysnc_allreduce: overlap all_reduce(dx) with matmul(dw) in ColumnParallelLinear.
+    # Flags_fused_linear_param_grad_add: enable fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
+    # Flags_skip_mp_c_identity: skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
+    assert flag in [
+        "Flags_mp_aysnc_allreduce",
+        "Flags_fused_linear_param_grad_add",
+        "Flags_skip_mp_c_identity",
+    ], "Only Flags_mp_aysnc_allreduce (overlap all_reduce(dx) with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (enable fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (skip c_identity in ColumnParallelLinear when Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear) are supported."
+    return str(os.getenv(flag)).lower() in ["true", "1"]
+
 
 def _c_identity(tensor, group=None):
     """
@@ -45,15 +79,20 @@ def _c_identity(tensor, group=None):
     class c_identity_eager(PyLayer):
         @staticmethod
         def forward(ctx, tensor):
-            return _legacy_C_ops.c_identity(
-                tensor,
-                'use_calc_stream',
-                True,
-                'ring_id',
-                group.id,
-                'use_model_parallel',
-                True,
-            )
+            if _get_mp_env_flag(
+                "Flags_mp_aysnc_allreduce"
+            ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
+                return tensor
+            else:
+                return _legacy_C_ops.c_identity(
+                    tensor,
+                    'use_calc_stream',
+                    True,
+                    'ring_id',
+                    group.id,
+                    'use_model_parallel',
+                    True,
+                )
 
         @staticmethod
         def backward(ctx, dy):
@@ -256,15 +295,20 @@ def _mp_allreduce(
 
         @staticmethod
         def backward(ctx, dy):
-            return _legacy_C_ops.c_identity(
-                dy,
-                'use_calc_stream',
-                True,
-                'ring_id',
-                ctx.ring_id,
-                'use_model_parallel',
-                True,
-            )
+            if _get_mp_env_flag(
+                "Flags_mp_aysnc_allreduce"
+            ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
+                return dy
+            else:
+                return _legacy_C_ops.c_identity(
+                    dy,
+                    'use_calc_stream',
+                    True,
+                    'ring_id',
+                    ctx.ring_id,
+                    'use_model_parallel',
+                    True,
+                )
 
     return mp_allreduce_eager.apply(
         tensor, group, use_calc_stream, use_model_parallel