diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index ce489352d3bcfbc7d191991bcf7f85d4be6c33ac..38b7acd93c445fe2062169da84dfc314782d536b 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -56,6 +56,12 @@ message MpConfig {
   optional bool sync_grad= 2 [ default = false ];
   optional bool sync_moment= 3 [ default = false ];
   optional string sync_mode= 4 [ default = 'broadcast' ];
+  // Support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
+  optional bool mp_async_allreduce= 5 [default = false ];
+  // Support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when mp_async_allreduce is true.
+  optional bool mp_skip_c_identity= 6 [default = false ];
+  // Support fused_linear_param_grad_add in ColumnParallelLinear. Only works when mp_async_allreduce is true.
+  optional bool mp_fused_linear_param_grad_add= 7 [default = false ];
 }

 message PpConfig {
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index c9a6b94576d6e3b67ff22015ae7b1d1742163c92..4112a8606a2774e6fd14f71fa2f9767a8caee827 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -14,13 +14,13 @@

 import paddle
 from paddle.autograd import PyLayer
+from paddle.distributed import fleet
 from paddle.fluid import core
 from paddle.nn import functional as F

 from ....communication.reduce import ReduceOp, _get_reduce_op
 from ...base import topology as tp
 from . import mp_ops
-from .mp_ops import _get_mp_env_flag
 from .random import get_rng_state_tracker

 __all__ = []
@@ -305,6 +305,21 @@ class ColumnParallelLinear(paddle.nn.Layer):
         self.linear = F.linear

         self.fuse_matmul_bias = fuse_matmul_bias
+
+        mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
+            "mp_configs"
+        ]
+        self.mp_async_allreduce = self.is_mp and mp_configs.mp_async_allreduce
+        self.mp_skip_c_identity = (
+            self.is_mp
+            and mp_configs.mp_async_allreduce
+            and mp_configs.mp_skip_c_identity
+        )
+        self.mp_fused_linear_param_grad_add = (
+            self.is_mp
+            and mp_configs.mp_async_allreduce
+            and mp_configs.mp_fused_linear_param_grad_add
+        )
         if self.fuse_matmul_bias:
             if not is_fused_matmul_bias_supported():
                 raise NotImplementedError(
@@ -322,15 +337,15 @@ class ColumnParallelLinear(paddle.nn.Layer):

         def _overlap_linear():
             fuse_matmul_bias = self.fuse_matmul_bias
+            mp_async_allreduce = self.mp_async_allreduce
+            mp_skip_c_identity = self.mp_skip_c_identity
+            mp_fused_linear_param_grad_add = self.mp_fused_linear_param_grad_add

             class InnerOverlapLinear(paddle.autograd.PyLayer):
                 @staticmethod
                 def forward(ctx, x, weight, bias):
                     ctx.save_for_backward(x, weight, bias)
-                    if (
-                        _get_mp_env_flag("Flags_mp_aysnc_allreduce")
-                        and _get_mp_env_flag("Flags_skip_mp_c_identity")
-                    ) is False:
+                    if mp_skip_c_identity is False:
                         x = paddle._legacy_C_ops.c_identity(
                             x,
                             'use_calc_stream',
@@ -358,12 +373,12 @@ class ColumnParallelLinear(paddle.nn.Layer):
                     # TODO(GhostScreaming): remove it in future.
                     tmp = paddle.ones([512])

-                    if _get_mp_env_flag("Flags_fused_linear_param_grad_add"):
+                    if mp_fused_linear_param_grad_add:
                         if not is_fused_linear_param_grad_add_supported():
                             raise NotImplementedError(
-                                "You set environment variable Flags_fused_linear_param_grad_add=True, "
+                                "You set mp_fused_linear_param_grad_add=True, "
                                 "however, the paddle you are using not support this operation. "
" - "Please unset Flags_fused_linear_param_grad_add or use paddle compiled " + "Please unset fused_linear_param_grad_add or use paddle compiled " "with cuda 11.6 or higher." ) @@ -449,12 +464,14 @@ class ColumnParallelLinear(paddle.nn.Layer): return InnerOverlapLinear.apply(x, self.weight, self.bias) - if _get_mp_env_flag("Flags_mp_aysnc_allreduce"): + if self.mp_async_allreduce: output_parallel = _overlap_linear() else: if self.is_mp: input_parallel = mp_ops._c_identity( - x, group=self.model_parallel_group + x, + group=self.model_parallel_group, + skip_c_identity_dynamic=self.mp_skip_c_identity, ) else: input_parallel = x @@ -570,6 +587,20 @@ class RowParallelLinear(paddle.nn.Layer): ) self.is_mp = self.world_size > 1 + mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[ + "mp_configs" + ] + self.mp_async_allreduce = self.is_mp and mp_configs.mp_async_allreduce + self.mp_skip_c_identity = ( + self.is_mp + and mp_configs.mp_async_allreduce + and mp_configs.mp_skip_c_identity + ) + self.mp_fused_linear_param_grad_add = ( + self.is_mp + and mp_configs.mp_async_allreduce + and mp_configs.mp_fused_linear_param_grad_add + ) assert in_features % self.world_size == 0, ( "Number of row of the weight for linear ({}) must be" " divisible by model parallel size ({})".format( @@ -642,6 +673,7 @@ class RowParallelLinear(paddle.nn.Layer): group=self.model_parallel_group, use_calc_stream=True, use_model_parallel=True, + skip_c_identity_dynamic=self.mp_skip_c_identity, ) else: output_parallel = self.linear( @@ -652,6 +684,7 @@ class RowParallelLinear(paddle.nn.Layer): group=self.model_parallel_group, use_calc_stream=True, use_model_parallel=True, + skip_c_identity_dynamic=self.mp_skip_c_identity, ) output = ( output_ + self.bias if self.bias is not None else output_ diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py index e208bf4232f196acfe8f1ab3429be6775243fa32..76692fae634ec140fae12f333723ec2f9ed0c65e 100644 --- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py +++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import paddle from paddle import _legacy_C_ops from paddle.distributed import collective @@ -24,40 +22,8 @@ from paddle.nn.utils import dygraph_utils from ....communication.reduce import ReduceOp, _get_reduce_op -_first_get_mp_env_flag = True - - -def _get_mp_env_flag(flag): - global _first_get_mp_env_flag - if _first_get_mp_env_flag: - print( - "Flags_mp_aysnc_allreduce is {}, which is used to support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear.".format( - str(os.getenv("Flags_mp_aysnc_allreduce")).lower() - ) - ) - print( - "Flags_fused_linear_param_grad_add is {}, which is used to support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format( - str(os.getenv("Flags_fused_linear_param_grad_add")).lower() - ) - ) - print( - "Flags_skip_mp_c_identity is {}, which is used to support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format( - str(os.getenv("Flags_skip_mp_c_identity")).lower() - ) - ) - # Model parallel environment flag. - # Flags_mp_aysnc_allreduce: support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear - # Flags_fused_linear_param_grad_add: support fused_linear_param_grad_add in ColumnParallelLinear. 
-    # Flags_skip_mp_c_identity: support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
-    assert flag in [
-        "Flags_mp_aysnc_allreduce",
-        "Flags_fused_linear_param_grad_add",
-        "Flags_skip_mp_c_identity",
-    ], "Only support set Flags_mp_aysnc_allreduce (support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (support fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (support skip c_identity in ColumnParallelLinear with Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear)"
-    return str(os.getenv(flag)).lower() in ["true", "1"]
-

-def _c_identity(tensor, group=None):
+def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
     """
     Return a copy of the tensor, mainly used with model parallel.

@@ -79,9 +45,7 @@
         class c_identity_eager(PyLayer):
             @staticmethod
             def forward(ctx, tensor):
-                if _get_mp_env_flag(
-                    "Flags_mp_aysnc_allreduce"
-                ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
+                if skip_c_identity_dynamic:
                     return tensor
                 else:
                     return _legacy_C_ops.c_identity(
@@ -260,6 +224,7 @@ def _mp_allreduce(
     group=None,
     use_calc_stream=True,
     use_model_parallel=True,
+    skip_c_identity_dynamic=False,
 ):
     """[it is same as allreduce above, but it supports model parallel. And it support inplace startegy]"""
     if group is not None and not group.is_member():
@@ -295,9 +260,7 @@

             @staticmethod
             def backward(ctx, dy):
-                if _get_mp_env_flag(
-                    "Flags_mp_aysnc_allreduce"
-                ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
+                if skip_c_identity_dynamic:
                     return dy
                 else:
                     return _legacy_C_ops.c_identity(
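
Usage note: with the MpConfig fields added above, the overlap options are meant to be driven from the distributed strategy instead of environment variables. A minimal sketch of how they might be enabled is shown below; it assumes the `hybrid_configs` setter accepts a nested "mp_configs" dict for these fields, and that the parallel layers are constructed after `fleet.init()` (they read the flags from `fleet.fleet._user_defined_strategy` in `__init__`). Degrees, shapes, and the launch command are illustrative only.

    # Launch across 2 GPUs, e.g.: python -m paddle.distributed.launch --gpus "0,1" demo.py
    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,
        "pp_degree": 1,
        # Assumption: the hybrid_configs setter forwards this nested dict to MpConfig.
        "mp_configs": {
            # Overlap all_reduce(dx) with matmul(dw) in ColumnParallelLinear.
            "mp_async_allreduce": True,
            # Skip c_identity; only honored when mp_async_allreduce is True.
            "mp_skip_c_identity": True,
            # Requires a paddle build with CUDA 11.6 or higher.
            "mp_fused_linear_param_grad_add": False,
        },
    }
    fleet.init(is_collective=True, strategy=strategy)

    # The layers pick the flags up from the active strategy when they are built,
    # so construct them only after fleet.init().
    fc = fleet.meta_parallel.ColumnParallelLinear(
        1024, 4096, has_bias=True, gather_output=False
    )
    out = fc(paddle.randn([8, 1024]))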