From e84250e8b49267e092f58ebe66a9373a9c940122 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 9 Aug 2022 10:30:49 +0800 Subject: [PATCH] [model parallel] enable mp to use fused linear (#44968) --- .../parallel_layers/mp_layers.py | 58 +++++++++++++++---- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py index 14ca1322e78..c9d7c71dbbb 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle +from paddle.fluid import core from paddle.fluid.dygraph.layers import Layer from .random import get_rng_state_tracker from paddle.nn import functional as F @@ -27,6 +28,13 @@ __all__ = [] # language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053) +def is_fused_matmul_bias_supported(): + if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(): + return hasattr(core.ops, 'fused_gemm_epilogue') + else: + return False + + class VocabParallelEmbedding(Layer): def __init__(self, @@ -100,7 +108,8 @@ class ColumnParallelLinear(Layer): weight_attr=None, has_bias=None, gather_output=True, - name=None): + name=None, + fuse_matmul_bias=False): super(ColumnParallelLinear, self).__init__() self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( @@ -147,6 +156,18 @@ class ColumnParallelLinear(Layer): else: self.bias = None + self.linear = F.linear + + if fuse_matmul_bias: + if not is_fused_matmul_bias_supported(): + raise NotImplementedError( + "You set fuse_matmul_bias=True in ColumnParallelLinear, " + "however, the paddle you are using not support this operation. " + "Please set fuse_matmul_bias=False or use paddle compiled " + "with cuda 11.6 or higher.") + from paddle.incubate.nn.functional import fused_linear + self.linear = fused_linear + def forward(self, x): # use inner api to process identity if self.is_mp: @@ -155,10 +176,10 @@ class ColumnParallelLinear(Layer): else: input_parallel = x - output_parallel = F.linear(input_parallel, - self.weight, - self.bias, - name=self._name) + output_parallel = self.linear(input_parallel, + self.weight, + self.bias, + name=self._name) if self.gather_output and self.is_mp: output = paddle.distributed.collective._c_concat( @@ -176,7 +197,8 @@ class RowParallelLinear(Layer): weight_attr=None, has_bias=True, input_is_parallel=False, - name=None): + name=None, + fuse_matmul_bias=False): super(RowParallelLinear, self).__init__() self.in_features = in_features @@ -225,6 +247,18 @@ class RowParallelLinear(Layer): else: self.bias = None + self.linear = F.linear + + if fuse_matmul_bias: + if not is_fused_matmul_bias_supported(): + raise NotImplementedError( + "You set fuse_matmul_bias=True in RowParallelLinear, " + "however, the paddle you are using not support this operation. " + "Please set fuse_matmul_bias=False or use paddle compiled " + "with cuda 11.6 or higher.") + from paddle.incubate.nn.functional import fused_linear + self.linear = fused_linear + def forward(self, x): if self.input_is_parallel or (not self.is_mp): input_parallel = x @@ -233,18 +267,22 @@ class RowParallelLinear(Layer): input_parallel = paddle.distributed.collective._c_split( x, group=self.model_parallel_group) - output_parallel = F.linear(input_parallel, self.weight, name=self._name) - if self.is_mp: + output_parallel = self.linear(input_parallel, + self.weight, + name=self._name) output_ = paddle.distributed.collective._mp_allreduce( output_parallel, group=self.model_parallel_group, use_calc_stream=True, use_model_parallel=True) + output = output_ + self.bias if self.bias is not None else output_ else: - output_ = output_parallel + output = self.linear(input_parallel, + self.weight, + self.bias, + name=self._name) - output = output_ + self.bias if self.bias is not None else output_ return output -- GitLab