Unverified commit fa084e5e, authored by niuliling123, committed by GitHub

Add fused_linear_activation (#55420)

Parent 30f059d6
@@ -15,7 +15,11 @@
 from .fused_transformer import fused_multi_head_attention
 from .fused_transformer import fused_feedforward
 from .fused_transformer import fused_multi_transformer
-from .fused_matmul_bias import fused_matmul_bias, fused_linear
+from .fused_matmul_bias import (
+    fused_matmul_bias,
+    fused_linear,
+    fused_linear_activation,
+)
 from .fused_transformer import fused_bias_dropout_residual_layer_norm
 from .fused_ec_moe import fused_ec_moe
 from .fused_dropout_add import fused_dropout_add
@@ -23,13 +27,13 @@ from .fused_gate_attention import fused_gate_attention
 from .fused_rotary_position_embedding import fused_rotary_position_embedding
 from .rms_norm import rms_norm

 __all__ = [
     'fused_multi_head_attention',
     'fused_feedforward',
     'fused_multi_transformer',
     'fused_matmul_bias',
     'fused_linear',
+    'fused_linear_activation',
     'fused_bias_dropout_residual_layer_norm',
     'fused_ec_moe',
     'fused_dropout_add',
...
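With the export above in place, the new op is reachable from the public paddle.incubate.nn.functional namespace. A minimal usage sketch (illustration only, not part of the commit; assumes a GPU build of Paddle with CUDA >= 11.6 and arbitrary example shapes):

    import paddle
    from paddle.incubate.nn.functional import fused_linear_activation

    # Multiply x by y, add bias, then apply the activation in a single fused kernel.
    x = paddle.randn([3, 4])
    y = paddle.randn([4, 5])
    bias = paddle.randn([5])
    out = fused_linear_activation(x, y, bias, activation='relu')
    print(out.shape)  # [3, 5]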
@@ -99,3 +99,63 @@ def fused_linear(x, weight, bias=None, transpose_weight=False, name=None):
             print(out.shape) # [3, 5]
     """
     return fused_matmul_bias(x, weight, bias, False, transpose_weight, name)
+
+
+def fused_linear_activation(
+    x, y, bias, trans_x=False, trans_y=False, activation=None
+):
+    """
+    Fully-connected linear and activation transformation operator. This method requires CUDA version >= 11.6.
+
+    Args:
+        x (Tensor): the input Tensor to be multiplied.
+        y (Tensor): the weight Tensor to be multiplied. Its rank must be 2.
+        bias (Tensor): the input bias Tensor, which is added to the matrix multiplication result.
+        trans_x (bool): Whether to transpose :math:`x` before multiplication.
+        trans_y (bool): Whether to transpose :math:`y` before multiplication.
+        activation (str|None): Activation function. Currently, the available activation functions are limited to "gelu" (Gaussian Error Linear Unit) and "relu" (Rectified Linear Unit), applied to the output of the bias add. Default: None (no activation).
+
+    Returns:
+        Tensor: the output Tensor.
+
+    Examples:
+        .. code-block:: python
+
+            # required: gpu
+            import paddle
+            from paddle.incubate.nn.functional import fused_linear_activation
+
+            x = paddle.randn([3, 4])
+            weight = paddle.randn([4, 5])
+            bias = paddle.randn([5])
+            out = fused_linear_activation(x, weight, bias)
+            print(out.shape) # [3, 5]
+    """
+    if activation is None:
+        activation = "none"
+
+    if in_dynamic_mode():
+        return _legacy_C_ops.fused_gemm_epilogue(
+            x,
+            y,
+            bias,
+            'trans_x',
+            trans_x,
+            'trans_y',
+            trans_y,
+            'activation',
+            activation,
+        )
+
+    helper = LayerHelper('fused_matmul_bias', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='fused_gemm_epilogue',
+        inputs={'X': x, 'Y': y, 'Bias': bias},
+        outputs={'Out': out},
+        attrs={
+            'trans_x': trans_x,
+            'trans_y': trans_y,
+            'activation': activation,
+        },
+    )
+    return out
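Conceptually, the new function computes the same result as an unfused matmul, bias add, and activation applied in sequence; in eager mode it instead dispatches once to the fused_gemm_epilogue kernel. A rough unfused reference for intuition (illustration only, not the kernel implementation; the helper name is made up):

    import paddle
    import paddle.nn.functional as F

    def linear_activation_reference(x, y, bias, trans_x=False, trans_y=False, activation=None):
        # Unfused equivalent: matmul with optional transposes, bias add, then activation.
        out = paddle.matmul(x, y, transpose_x=trans_x, transpose_y=trans_y) + bias
        if activation == 'relu':
            out = F.relu(out)
        elif activation == 'gelu':
            out = F.gelu(out)
        return out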
@@ -20,6 +20,7 @@ from eager_op_test import OpTest, skip_check_grad_ci, skip_check_inplace_ci
 import paddle
 from paddle.fluid import core
+from paddle.incubate.nn.functional import fused_linear_activation


 def is_fused_gemm_epilogue_supported():
@@ -587,13 +588,15 @@ class TestEagerFusedGemmEpilogue(unittest.TestCase):
         x.stop_gradient = False
         y.stop_gradient = False

-        out1 = core.eager.ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'none'
-        )
-        out2 = core.eager.ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'relu'
-        )
-        out3 = core.eager.ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'gelu'
-        )
+        out1 = fused_linear_activation(
+            x, y, bias, trans_x=False, trans_y=False, activation='none'
+        )
+        out2 = fused_linear_activation(
+            x, y, bias, trans_x=False, trans_y=False, activation='relu'
+        )
+        out3 = fused_linear_activation(
+            x, y, bias, trans_x=False, trans_y=False, activation='gelu'
+        )
...
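The updated test exercises the new Python API in eager mode instead of calling core.eager.ops.fused_gemm_epilogue directly. A hedged sketch of the kind of check such a test performs, comparing the fused output against an unfused reference (shapes, names, and tolerances are illustrative, not copied from the test file; requires a GPU with CUDA >= 11.6):

    import numpy as np
    import paddle
    import paddle.nn.functional as F
    from paddle.incubate.nn.functional import fused_linear_activation

    x = paddle.randn([8, 4])
    y = paddle.randn([4, 128])
    bias = paddle.randn([128])

    fused = fused_linear_activation(x, y, bias, activation='relu')
    reference = F.relu(paddle.matmul(x, y) + bias)
    np.testing.assert_allclose(fused.numpy(), reference.numpy(), rtol=1e-5, atol=1e-5)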