diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py
index d5ef54d7b7b0386acf30d9476f69d6de5b6f8e03..207a4fcb036c3e50fe9119ad23b8cf8b7eb3cae4 100644
--- a/python/paddle/incubate/nn/functional/__init__.py
+++ b/python/paddle/incubate/nn/functional/__init__.py
@@ -15,7 +15,11 @@
 from .fused_transformer import fused_multi_head_attention
 from .fused_transformer import fused_feedforward
 from .fused_transformer import fused_multi_transformer
-from .fused_matmul_bias import fused_matmul_bias, fused_linear
+from .fused_matmul_bias import (
+    fused_matmul_bias,
+    fused_linear,
+    fused_linear_activation,
+)
 from .fused_transformer import fused_bias_dropout_residual_layer_norm
 from .fused_ec_moe import fused_ec_moe
 from .fused_dropout_add import fused_dropout_add
@@ -23,13 +27,13 @@
 from .fused_gate_attention import fused_gate_attention
 from .fused_rotary_position_embedding import fused_rotary_position_embedding
 from .rms_norm import rms_norm
-
 __all__ = [
     'fused_multi_head_attention',
     'fused_feedforward',
     'fused_multi_transformer',
     'fused_matmul_bias',
     'fused_linear',
+    'fused_linear_activation',
     'fused_bias_dropout_residual_layer_norm',
     'fused_ec_moe',
     'fused_dropout_add',
diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py
index 72fbdea4535edcdfc52c3b375db77f5f9321a066..0fbb63025e3cf84d25c150d0fa65c90e9f4a7cbe 100644
--- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py
+++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py
@@ -99,3 +99,64 @@
             print(out.shape) # [3, 5]
     """
     return fused_matmul_bias(x, weight, bias, False, transpose_weight, name)
+
+
+def fused_linear_activation(
+    x, y, bias, trans_x=False, trans_y=False, activation=None
+):
+    """
+    Fully-connected linear and activation transformation operator. This method requires CUDA version >= 11.6.
+
+    Args:
+        x (Tensor): the input Tensor to be multiplied.
+        y (Tensor): the weight Tensor to be multiplied. Its rank must be 2.
+        bias (Tensor): the input bias Tensor, which is added to the matrix multiplication result.
+        trans_x (bool, optional): Whether to transpose :math:`x` before multiplication. Default: False.
+        trans_y (bool, optional): Whether to transpose :math:`y` before multiplication. Default: False.
+        activation (str, optional): Activation function applied to the output of the bias add. Currently, the available activation functions are limited to "gelu" (Gaussian Error Linear Unit) and "relu" (Rectified Linear Unit). Default: None, which means no activation is applied.
+    Returns:
+        Tensor: the output Tensor.
+
+    Examples:
+        .. code-block:: python
+
+            # required: gpu
+            import paddle
+            from paddle.incubate.nn.functional import fused_linear_activation
+
+            x = paddle.randn([3, 4])
+            y = paddle.randn([4, 5])
+            bias = paddle.randn([5])
+            out = fused_linear_activation(x, y, bias, activation="relu")
+            print(out.shape) # [3, 5]
+    """
+    if activation is None:
+        activation = "none"
+
+    if in_dynamic_mode():
+        return _legacy_C_ops.fused_gemm_epilogue(
+            x,
+            y,
+            bias,
+            'trans_x',
+            trans_x,
+            'trans_y',
+            trans_y,
+            'activation',
+            activation,
+        )
+
+    helper = LayerHelper('fused_matmul_bias', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='fused_gemm_epilogue',
+        inputs={'X': x, 'Y': y, 'Bias': bias},
+        outputs={'Out': out},
+        attrs={
+            'trans_x': trans_x,
+            'trans_y': trans_y,
+            'activation': activation,
+        },
+    )
+
+    return out
diff --git a/test/legacy_test/test_fused_gemm_epilogue_op.py b/test/legacy_test/test_fused_gemm_epilogue_op.py
index 49095020279d6b3f79cbc0e8faa692c2612b941a..5aeae6671882ecb05fa5c92ae27ee1bcfbfc766e 100644
--- a/test/legacy_test/test_fused_gemm_epilogue_op.py
+++ b/test/legacy_test/test_fused_gemm_epilogue_op.py
@@ -20,6 +20,7 @@ from eager_op_test import OpTest, skip_check_grad_ci, skip_check_inplace_ci
 
 import paddle
 from paddle.fluid import core
+from paddle.incubate.nn.functional import fused_linear_activation
 
 
 def is_fused_gemm_epilogue_supported():
@@ -587,13 +588,15 @@ class TestEagerFusedGemmEpilogue(unittest.TestCase):
         x.stop_gradient = False
         y.stop_gradient = False
 
-        out1 = core.eager.ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'none'
-        )
-        out2 = core.eager.ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'relu'
-        )
-        out3 = core.eager.ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'gelu'
-        )
+        out1 = fused_linear_activation(
+            x, y, bias, trans_x=False, trans_y=False, activation='none'
+        )
+
+        out2 = fused_linear_activation(
+            x, y, bias, trans_x=False, trans_y=False, activation='relu'
+        )
+
+        out3 = fused_linear_activation(
+            x, y, bias, trans_x=False, trans_y=False, activation='gelu'
+        )
 
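
A minimal usage sketch of the new API (not part of the diff above): it assumes a GPU build of Paddle with CUDA >= 11.6, calls fused_linear_activation in dynamic mode, and checks it against an unfused reference built from plain matmul, bias add, and relu; the tolerance values are arbitrary loose choices.

    import numpy as np

    import paddle
    import paddle.nn.functional as F
    from paddle.incubate.nn.functional import fused_linear_activation

    x = paddle.randn([3, 4], dtype='float32')
    y = paddle.randn([4, 5], dtype='float32')
    bias = paddle.randn([5], dtype='float32')

    # Fused path: matmul + bias add + relu in a single fused_gemm_epilogue call.
    fused = fused_linear_activation(x, y, bias, activation='relu')

    # Unfused reference computed with separate ops.
    reference = F.relu(paddle.matmul(x, y) + bias)

    np.testing.assert_allclose(
        fused.numpy(), reference.numpy(), rtol=1e-5, atol=1e-5
    )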