diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
index 4fc85a07511f6d6640a93fe63088193d6751cc8b..cb130ae0b236541916cd2b7eb71b8abfa8ad0ea2 100644
--- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
+++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -30,6 +30,13 @@ vjp_interface_declare_gen_op_list = [
     "add",
     "concat",
     "split",
+    "gelu",
+    "matmul",
+    "erf",
+    "multiply",
+    "subtract",
+    "pow",
+    "rsqrt",
 ]
 vjp_interface_implementation_gen_op_list = [
     "tanh",
@@ -38,4 +45,11 @@ vjp_interface_implementation_gen_op_list = [
     "add",
     "concat",
     "split",
+    "gelu",
+    "matmul",
+    "erf",
+    "multiply",
+    "subtract",
+    "pow",
+    "rsqrt",
 ]
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
index 64cb1d69b210a7295752fed4fa662d9d04de5344..33469fef8fa32fc8725ae1ff56f5f135f9d21ade 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
@@ -39,8 +39,8 @@ OpInfoTuple AddNOp::GetOpInfo() {
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
   std::vector<paddle::dialect::OpOutputInfo> outputs = {
       OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)};
-  paddle::dialect::OpRunTimeInfo run_time_info =
-      OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}, {});
+  paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo(
+      "AddNInferMeta", {"inputs"}, {"add_n"}, {"inputs"}, {}, {}, {}, {});
 
   return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
 }
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
index fe9beb46012edf06f8601fa5a0ef2178ed8ac0ff..b0ff45d9baaff8c94626a917aecd4e7f4685b627 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
@@ -34,7 +34,7 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp
 namespace paddle {
 namespace dialect {
 
-class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface> {
+class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
  public:
   using Op::Op;
   static const char *name() { return "pd.add_n"; }
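For context on the two lists above: each op named there is declared and implemented with a generated `VjpInterface`, i.e. a reverse-mode rule that maps the output cotangent back to input cotangents. A minimal NumPy sketch of what such rules compute for `multiply` and `subtract` (illustrative only; the helper names are hypothetical, not Paddle's generated code):

```python
import numpy as np

def multiply_vjp(x, y, out_grad):
    # out = x * y  =>  dL/dx = out_grad * y, dL/dy = out_grad * x
    return out_grad * y, out_grad * x

def subtract_vjp(x, y, out_grad):
    # out = x - y  =>  dL/dx = out_grad, dL/dy = -out_grad
    return out_grad, -out_grad

x = np.array([1.0, 2.0])
y = np.array([3.0, 4.0])
g = np.ones_like(x)  # cotangent of the output
assert np.allclose(multiply_vjp(x, y, g)[0], y)
assert np.allclose(subtract_vjp(x, y, g)[1], -g)
```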
diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py
index 940249d0ae5265daab6aaffad736162bb1319c46..c166c131d10cba5547124d3c0f99c2a73fdd6753 100644
--- a/python/paddle/decomposition/decomp.py
+++ b/python/paddle/decomposition/decomp.py
@@ -169,11 +169,14 @@ def decompose(
         dst_vars,
         op_filter,
     )
-    for item in dst_vars:
+    for idx, item in enumerate(dst_vars):
         if not isinstance(item, ir.OpResult):
-            raise TypeError(
-                f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {dst_vars}."
-            )
+            if item is None:
+                dst_vars[idx] = src_vars[idx]
+            else:
+                raise TypeError(
+                    f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {dst_vars}."
+                )
     logging.debug(
         "Decompose composite forward ops finish: {}".format(
             core.prim_config["composite_ops_record"]
diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py
index 4184dbdbea62e4eedf2dd47b6de147e4591b75f3..ef225ce461382a9b6e2c00fd759a2bc841bc4278 100644
--- a/python/paddle/decomposition/rules.py
+++ b/python/paddle/decomposition/rules.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from paddle import _ir_ops
+
 from .primitives import *  # noqa: F403
 from .register import register_decomp
 
@@ -34,3 +36,40 @@ def mean(x, axis, keepdim):
     )
     res = divide(sum_x, norm)
     return res
+
+
+@register_decomp('pd.gelu')
+def gelu_composite(x, approximate):
+    """define composite rule of op gelu"""
+    M_SQRT1_2 = (
+        0.70710678118654752440  # /* 1/sqrt(2) */ copy from gelu-kernel.cc
+    )
+    M_2_SQRTPI = 1.12837916709551257390  # /* 2/sqrt(pi) */
+    full_shape = x.shape if len(x.shape) == 0 else [1]
+    one = ones(full_shape, x.dtype)
+    half = full(full_shape, 0.5, x.dtype)
+    # Todo(cz): after symbol overload, add and multiply will be replaced by "+" and "*"
+    if approximate:
+        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
+        kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
+        GELU_CONSTANT = full(full_shape, 0.044715, x.dtype)
+        tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
+        out = x * half * (one + tanh_out)
+        return out
+
+    else:
+        # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+
+        cdf = _ir_ops.multiply(
+            half,
+            (
+                _ir_ops.add(
+                    one,
+                    _ir_ops.erf(
+                        _ir_ops.multiply(x, full(x.shape, M_SQRT1_2, x.dtype))
+                    ),
+                )
+            ),
+        )
+        out = _ir_ops.multiply(x, cdf)
+        return out
diff --git a/test/prim/new_ir_prim/CMakeLists.txt b/test/prim/new_ir_prim/CMakeLists.txt
index 85611d846cbbe09920b3ad6f6a3a53e686c249d8..72fe311f270e65c7b0290969aa2944b2853588c1 100644
--- a/test/prim/new_ir_prim/CMakeLists.txt
+++ b/test/prim/new_ir_prim/CMakeLists.txt
@@ -1,4 +1,4 @@
-set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program)
+set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet)
 
 foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
   py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
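The `gelu_composite` rule above encodes the two standard GELU forms: the exact erf expression and the tanh approximation used when `approximate=True`. A self-contained NumPy check of the identity, assuming nothing beyond the constants appearing in the rule (a sketch, not part of the patch):

```python
import math

import numpy as np

M_SQRT1_2 = 0.70710678118654752440   # 1/sqrt(2)
M_2_SQRTPI = 1.12837916709551257390  # 2/sqrt(pi)

def gelu_exact(x):
    # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x * M_SQRT1_2))

def gelu_tanh(x):
    # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    k_alpha = M_2_SQRTPI * M_SQRT1_2  # == sqrt(2/pi)
    return 0.5 * x * (1.0 + np.tanh(k_alpha * (x + 0.044715 * x**3)))

x = np.linspace(-4.0, 4.0, 101)
# The tanh form tracks the erf form to within ~1e-3 on this range.
np.testing.assert_allclose(gelu_exact(x), gelu_tanh(x), atol=1e-3)
```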
diff --git a/test/prim/new_ir_prim/test_prim_simpnet.py b/test/prim/new_ir_prim/test_prim_simpnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..505152354b9868f3bb81a14b204355a4c63d8adf
--- /dev/null
+++ b/test/prim/new_ir_prim/test_prim_simpnet.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import _ir_ops, nn
+from paddle.autograd.ir_backward import grad
+from paddle.decomposition import decompose
+from paddle.framework import core
+
+paddle.enable_static()
+
+
+class SimpNet(nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, linear1_weight, linear2_weight):
+        x2 = _ir_ops.matmul(x, linear1_weight, False, False)
+        x3 = _ir_ops.gelu(x2, False)
+        res = _ir_ops.matmul(x3, linear2_weight, False, False)
+        return res
+
+
+class TestPrimMode(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [2, 1024, 1024]
+        self.shape_y = [2, 1024, 1024]
+        self.shape_l1_w = [2, 1024, 4096]
+        self.shape_l2_w = [2, 4096, 1024]
+        self.x = np.random.random(self.shape_x).astype("float32")
+        self.y = np.random.random(self.shape_y).astype("float32")
+        self.l1_w = np.random.random(self.shape_l1_w).astype("float32")
+        self.l2_w = np.random.random(self.shape_l2_w).astype("float32")
+
+    def base_net(self, flag=None):
+        if flag == "all":
+            core._set_prim_all_enabled(True)
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
+            net = SimpNet()
+            x = paddle.static.data('x', self.shape_x, dtype='float32')
+            y = paddle.static.data('y', self.shape_y, dtype='float32')
+            x.stop_gradient = False
+            y.stop_gradient = False
+            l1_w = paddle.static.data('l1_w', self.shape_l1_w, dtype='float32')
+            l2_w = paddle.static.data('l2_w', self.shape_l2_w, dtype='float32')
+            divide_out = paddle.divide(x, y)
+            res = net(divide_out, l1_w, l2_w)
+            [res2] = decompose(main_program, [res])
+            gradients = grad(res2, (x, y))
+            exe = paddle.static.Executor()
+            outs = exe.run(
+                feed={
+                    'x': self.x,
+                    'y': self.y,
+                    'l1_w': self.l1_w,
+                    'l2_w': self.l2_w,
+                },
+                fetch_list=[res2, gradients[0], gradients[1]],
+            )
+
+        whole_ops = [op.name() for op in main_program.block().ops]
+        if flag == "all":
+            core._set_prim_all_enabled(False)
+            assert (
+                'pd.gelu' not in whole_ops and 'pd.divide_grad' not in whole_ops
+            )
+        return outs
+
+    def test_prim_all(self):
+        res_ref = self.base_net()
+        res = self.base_net("all")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+
+if __name__ == "__main__":
+    unittest.main()
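The test's final assertion checks that decomposition lowered both the `pd.gelu` forward op and the `pd.divide_grad` backward op into primitives. The gradients the lowered graph must still reproduce follow the quotient rule; a quick NumPy spot check (illustrative only, not part of the patch):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
g = np.ones_like(x)  # upstream gradient for out = x / y

# Quotient rule: d(x/y)/dx = 1/y, d(x/y)/dy = -x / y**2
dx = g / y
dy = -g * x / y**2

# Finite-difference spot check on dx.
eps = 1e-6
num_dx = ((x + eps) / y - x / y) / eps
np.testing.assert_allclose(dx, num_dx, rtol=1e-4)
```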