Unverified commit e84250e8, authored by Yuang Liu, committed by GitHub

[model parallel] enable mp to use fused linear (#44968)

Parent c91aaced
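
This commit threads a new `fuse_matmul_bias` argument through `ColumnParallelLinear` and `RowParallelLinear`. A minimal usage sketch follows; it assumes a multi-process launch (e.g. via `python -m paddle.distributed.launch`), a paddle build with CUDA 11.6 or higher, and that the layers are importable from `paddle.distributed.fleet.meta_parallel` as in other model-parallel examples. None of this setup is shown in the diff itself.

```python
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
                                                    RowParallelLinear)

# Two-way tensor (model) parallelism; dp/pp left at 1 for the sketch.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# fuse_matmul_bias=True routes forward() through
# paddle.incubate.nn.functional.fused_linear instead of paddle.nn.functional.linear.
col = ColumnParallelLinear(1024, 4096, has_bias=True, gather_output=False,
                           fuse_matmul_bias=True)
row = RowParallelLinear(4096, 1024, has_bias=True, input_is_parallel=True,
                        fuse_matmul_bias=True)
```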
@@ -13,6 +13,7 @@
 # limitations under the License.
 import paddle
+from paddle.fluid import core
 from paddle.fluid.dygraph.layers import Layer
 from .random import get_rng_state_tracker
 from paddle.nn import functional as F
@@ -27,6 +28,13 @@ __all__ = []
 # language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)
+def is_fused_matmul_bias_supported():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+        return hasattr(core.ops, 'fused_gemm_epilogue')
+    else:
+        return False
+
+
 class VocabParallelEmbedding(Layer):
     def __init__(self,
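
`fused_gemm_epilogue` performs the matmul and the bias add in a single kernel (it is backed by cuBLASLt epilogues, hence the CUDA 11.6+ requirement cited in the error messages below). A quick equivalence check, assuming a build where the fused op is available:

```python
import paddle
from paddle.nn import functional as F
from paddle.incubate.nn.functional import fused_linear

x = paddle.randn([8, 16])
w = paddle.randn([16, 32])
b = paddle.randn([32])

y_ref = F.linear(x, w, b)        # matmul followed by a separate bias add
y_fused = fused_linear(x, w, b)  # one fused_gemm_epilogue kernel
print(paddle.allclose(y_ref, y_fused))  # True, up to numerical tolerance
```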
@@ -100,7 +108,8 @@ class ColumnParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=None,
                  gather_output=True,
-                 name=None):
+                 name=None,
+                 fuse_matmul_bias=False):
         super(ColumnParallelLinear, self).__init__()
 
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
@@ -147,6 +156,18 @@ class ColumnParallelLinear(Layer):
         else:
             self.bias = None
 
+        self.linear = F.linear
+
+        if fuse_matmul_bias:
+            if not is_fused_matmul_bias_supported():
+                raise NotImplementedError(
+                    "You set fuse_matmul_bias=True in ColumnParallelLinear, "
+                    "however, the paddle you are using does not support this "
+                    "operation. Please set fuse_matmul_bias=False or use "
+                    "paddle compiled with CUDA 11.6 or higher.")
+            from paddle.incubate.nn.functional import fused_linear
+            self.linear = fused_linear
+
     def forward(self, x):
         # use inner api to process identity
         if self.is_mp:
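
`self.linear` is bound once at construction time, so `forward` needs no per-call branching. The same dispatch pattern in isolation, as a hypothetical standalone layer (not part of this commit):

```python
import paddle
from paddle.nn import functional as F

class FusableLinear(paddle.nn.Layer):
    """Hypothetical layer illustrating the construction-time dispatch."""

    def __init__(self, in_features, out_features, fuse_matmul_bias=False):
        super().__init__()
        self.weight = self.create_parameter([in_features, out_features])
        self.bias = self.create_parameter([out_features], is_bias=True)
        self.linear = F.linear  # default: unfused matmul + bias add
        if fuse_matmul_bias:
            from paddle.incubate.nn.functional import fused_linear
            self.linear = fused_linear  # same call signature, fused kernel

    def forward(self, x):
        return self.linear(x, self.weight, self.bias)
```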
@@ -155,10 +176,10 @@ class ColumnParallelLinear(Layer):
         else:
             input_parallel = x
 
-        output_parallel = F.linear(input_parallel,
-                                   self.weight,
-                                   self.bias,
-                                   name=self._name)
+        output_parallel = self.linear(input_parallel,
+                                      self.weight,
+                                      self.bias,
+                                      name=self._name)
 
         if self.gather_output and self.is_mp:
             output = paddle.distributed.collective._c_concat(
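
Note that the column-parallel forward can pass the bias straight into the fused call: the output dimension is the one that is split, so each rank owns a disjoint slice of both the weight columns and the bias. A single-process illustration:

```python
import paddle

x = paddle.randn([4, 8])
w = paddle.randn([8, 6])
b = paddle.randn([6])

w0, w1 = paddle.split(w, 2, axis=1)  # each "rank" keeps half the output columns
b0, b1 = paddle.split(b, 2, axis=0)  # ...and the matching bias slice

parts = paddle.concat([paddle.matmul(x, w0) + b0,
                       paddle.matmul(x, w1) + b1], axis=-1)
print(paddle.allclose(paddle.matmul(x, w) + b, parts))  # True
```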
@@ -176,7 +197,8 @@ class RowParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=True,
                  input_is_parallel=False,
-                 name=None):
+                 name=None,
+                 fuse_matmul_bias=False):
         super(RowParallelLinear, self).__init__()
 
         self.in_features = in_features
@@ -225,6 +247,18 @@ class RowParallelLinear(Layer):
         else:
             self.bias = None
 
+        self.linear = F.linear
+
+        if fuse_matmul_bias:
+            if not is_fused_matmul_bias_supported():
+                raise NotImplementedError(
+                    "You set fuse_matmul_bias=True in RowParallelLinear, "
+                    "however, the paddle you are using does not support this "
+                    "operation. Please set fuse_matmul_bias=False or use "
+                    "paddle compiled with CUDA 11.6 or higher.")
+            from paddle.incubate.nn.functional import fused_linear
+            self.linear = fused_linear
+
     def forward(self, x):
         if self.input_is_parallel or (not self.is_mp):
             input_parallel = x
@@ -233,18 +267,22 @@ class RowParallelLinear(Layer):
             input_parallel = paddle.distributed.collective._c_split(
                 x, group=self.model_parallel_group)
 
-        output_parallel = F.linear(input_parallel, self.weight, name=self._name)
         if self.is_mp:
+            output_parallel = self.linear(input_parallel,
+                                          self.weight,
+                                          name=self._name)
             output_ = paddle.distributed.collective._mp_allreduce(
                 output_parallel,
                 group=self.model_parallel_group,
                 use_calc_stream=True,
                 use_model_parallel=True)
+            output = output_ + self.bias if self.bias is not None else output_
         else:
-            output_ = output_parallel
+            output = self.linear(input_parallel,
+                                 self.weight,
+                                 self.bias,
+                                 name=self._name)
-        output = output_ + self.bias if self.bias is not None else output_
 
         return output
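
The row-parallel branch deliberately keeps the bias out of the fused call when model parallelism is on: the contraction dimension is split, each rank produces only a partial product, and the bias must be added exactly once after the allreduce. Fusing it per rank would add the bias `world_size` times. A single-process illustration of the pitfall:

```python
import paddle

x = paddle.randn([4, 8])
w = paddle.randn([8, 6])
b = paddle.randn([6])

x0, x1 = paddle.split(x, 2, axis=1)  # split the contraction dim across two "ranks"
w0, w1 = paddle.split(w, 2, axis=0)

full = paddle.matmul(x, w) + b
correct = paddle.matmul(x0, w0) + paddle.matmul(x1, w1) + b        # bias added once
wrong = (paddle.matmul(x0, w0) + b) + (paddle.matmul(x1, w1) + b)  # bias added twice
print(paddle.allclose(full, correct))  # True
print(paddle.allclose(full, wrong))    # False
```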