From 1b8619c768ca796005ea0f3ede29b5fd34e83b89 Mon Sep 17 00:00:00 2001
From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com>
Date: Fri, 8 Sep 2023 14:30:52 +0800
Subject: [PATCH] [Prim][NewIR] Add test case prim custom vjp in NewIR (#57030)

* test prim custom vjp in New IR

* polish gelu_grad
---
 paddle/fluid/primitive/codegen/gen.py         |   3 +-
 paddle/fluid/primitive/rule/vjp/details.h     |  38 ++++++
 test/prim/new_ir_prim/CMakeLists.txt          |   3 +-
 test/prim/new_ir_prim/test_prim_custom_vjp.py | 108 ++++++++++++++++++
 4 files changed, 150 insertions(+), 2 deletions(-)
 create mode 100644 test/prim/new_ir_prim/test_prim_custom_vjp.py

diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index 722ed94953d..43eb5005f0f 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -63,13 +63,14 @@ VJPS = [
     'transpose_grad',
     'dropout_grad',
 ]
-VJP_COMPS = ['divide_grad', 'sum_grad']
+VJP_COMPS = ['divide_grad', 'sum_grad', 'gelu_grad']
 BACKENDS = [
     'add_n',
     'mean',
     'sum',
     'divide',
     'full',
+    'tanh',
     'tanh_grad',
     'mean_grad',
     'concat',
diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index 12fb66127a2..eb640a4643e 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -125,6 +125,44 @@ void sum_grad(const Tensor& x,
   set_output<T>(x_grad_tmp, x_grad);
 }
 
+template <typename T>
+void gelu_grad(const Tensor& x,
+               const Tensor& out_grad,
+               bool approximate,
+               Tensor* x_grad) {
+  if (!x_grad) return;
+  // Promote to fp32 when the input type is fp16 for keeping consistent with
+  // phi kernel
+
+  // Scale only support fp32 attr in static graph mode, use elementwise_xx
+  // when precision is over fp32.
+  if (approximate) {
+    auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+    auto kKappa = 0.044715;
+    auto x_sq = x * x;
+    auto x_cube = x_sq * x;
+    auto inner = kBeta * (x + kKappa * x_cube);
+    auto tanh_inner = tanh<T>(inner);
+
+    auto left = scale<T>(x, 0.5);
+    auto right = scale<T>(tanh_inner, 1., 1.);
+
+    auto left_derivative = scale<T>(right, 0.5);
+
+    auto tanh_derivative = scale<T>(tanh_inner * tanh_inner, -1., 1.);
+    auto inner_derivative = kBeta * (scale<T>(3 * kKappa * x_sq, 1., 1.));
+    auto right_derivative = left * tanh_derivative * inner_derivative;
+
+    set_output<T>(out_grad * (left_derivative + right_derivative), x_grad);
+  } else {
+    auto kAlpha = M_SQRT1_2;
+    auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
+    auto cdf = scale<T>(scale<T>(erf<T>(kAlpha * x), 1., 1.), 0.5);
+    auto pdf = kBeta * exp<T>(scale<T>(x * x, -0.5));
+    set_output<T>(out_grad * (cdf + x * pdf), x_grad);
+  }
+}
+
 }  // namespace details
 }  // namespace primitive
 }  // namespace paddle
diff --git a/test/prim/new_ir_prim/CMakeLists.txt b/test/prim/new_ir_prim/CMakeLists.txt
index 72fe311f270..1b37b432d20 100644
--- a/test/prim/new_ir_prim/CMakeLists.txt
+++ b/test/prim/new_ir_prim/CMakeLists.txt
@@ -1,4 +1,5 @@
-set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet)
+set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet
+                                test_prim_custom_vjp)
 
 foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
   py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
diff --git a/test/prim/new_ir_prim/test_prim_custom_vjp.py b/test/prim/new_ir_prim/test_prim_custom_vjp.py
new file mode 100644
index 00000000000..6cd0527ff64
--- /dev/null
+++ b/test/prim/new_ir_prim/test_prim_custom_vjp.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import _ir_ops, nn
+from paddle.autograd.ir_backward import grad
+from paddle.decomposition import decompose
+from paddle.framework import core
+
+paddle.enable_static()
+
+
+class SimpNet(nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, linear1_weight, linear2_weight):
+        x2 = _ir_ops.matmul(x, linear1_weight, False, False)
+        x3 = _ir_ops.gelu(x2, False)
+        res = _ir_ops.matmul(x3, linear2_weight, False, False)
+        return res
+
+
+class TestPrimMode(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [2, 1024, 1024]
+        self.shape_y = [2, 1024, 1024]
+        self.shape_l1_w = [2, 1024, 4096]
+        self.shape_l2_w = [2, 4096, 1024]
+        self.x = np.random.random(self.shape_x).astype("float32")
+        self.y = np.random.random(self.shape_y).astype("float32")
+        self.l1_w = np.random.random(self.shape_l1_w).astype("float32")
+        self.l2_w = np.random.random(self.shape_l2_w).astype("float32")
+
+    def base_net(self, flag=None):
+        if flag == "backward":
+            core._set_prim_backward_enabled(True)
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
+            net = SimpNet()
+            x = paddle.static.data('x', self.shape_x, dtype='float32')
+            y = paddle.static.data('y', self.shape_y, dtype='float32')
+            x.stop_gradient = False
+            y.stop_gradient = False
+            l1_w = paddle.static.data('l1_w', self.shape_l1_w, dtype='float32')
+            l2_w = paddle.static.data('l2_w', self.shape_l2_w, dtype='float32')
+            divide_out = paddle.divide(x, y)
+            res = net(divide_out, l1_w, l2_w)
+            [res2] = decompose(
+                main_program,
+                [res],
+            )
+            gradients = grad(res2, (x, y))
+
+            if flag == "backward":
+                whole_ops_before = [
+                    op.name() for op in main_program.block().ops
+                ]
+                assert (
+                    "pd.gelu" in whole_ops_before
+                    and "pd.gelu_grad" not in whole_ops_before
+                )
+                core._set_prim_forward_enabled(True)
+                [res2] = decompose(main_program, [res2], whitelist={"pd.gelu"})
+                whole_ops_after = [op.name() for op in main_program.block().ops]
+                assert "pd.gelu" not in whole_ops_after
+                core._set_prim_forward_enabled(False)
+
+            exe = paddle.static.Executor()
+            outs = exe.run(
+                feed={
+                    'x': self.x,
+                    'y': self.y,
+                    'l1_w': self.l1_w,
+                    'l2_w': self.l2_w,
+                },
+                fetch_list=[res2, gradients[0], gradients[1]],
+            )
+
+        if flag == "backward":
+            core._set_prim_backward_enabled(False)
+        return outs
+
+    def test_prim_custom_vjp(self):
+        res_ref = self.base_net()
+        res = self.base_net("backward")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_allclose(ref, actual, rtol=1e-6)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
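Note on the composite rule added in details.h: gelu_grad builds the analytic GELU derivative out of scale/tanh/erf/exp primitives. For the tanh approximation it computes 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * kBeta * (1 + 3 * kKappa * x^2) with u = kBeta * (x + kKappa * x^3), and for the exact form it computes Phi(x) + x * phi(x), where Phi and phi are the standard normal CDF and PDF. The minimal sketch below is not part of the patch and uses only the Python standard library; the helpers gelu and gelu_grad_ref are illustrative names for this note, not Paddle APIs. It checks the closed-form gradients against a central finite difference of the corresponding forward formulas.

# Standalone sanity check of the closed-form GELU gradients that the composite
# gelu_grad rule composes from scale/tanh/erf/exp primitives.
# gelu() and gelu_grad_ref() are illustrative helpers, not Paddle APIs.
import math


def gelu(x, approximate):
    # Forward GELU: tanh approximation or exact (erf-based) form.
    if approximate:
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
        return 0.5 * x * (1.0 + math.tanh(inner))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_grad_ref(x, approximate):
    # Same closed-form derivatives as the composite rule in details.h.
    if approximate:
        k_beta = math.sqrt(2.0 / math.pi)  # M_SQRT2 * M_2_SQRTPI * 0.5
        k_kappa = 0.044715
        tanh_inner = math.tanh(k_beta * (x + k_kappa * x**3))
        left_derivative = 0.5 * (1.0 + tanh_inner)
        right_derivative = (
            0.5 * x * (1.0 - tanh_inner**2) * k_beta * (1.0 + 3.0 * k_kappa * x**2)
        )
        return left_derivative + right_derivative
    cdf = 0.5 * (1.0 + math.erf(x * math.sqrt(0.5)))         # Phi(x)
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)  # phi(x)
    return cdf + x * pdf


if __name__ == "__main__":
    eps = 1e-5
    for approximate in (True, False):
        for x in (-3.0, -0.7, 0.0, 0.4, 2.5):
            # Central finite difference of the forward formula.
            fd = (gelu(x + eps, approximate) - gelu(x - eps, approximate)) / (2 * eps)
            assert abs(fd - gelu_grad_ref(x, approximate)) < 1e-6, (x, approximate)
    print("composite gelu_grad formulas match finite differences")

This is the same consistency that test_prim_custom_vjp checks end to end: the program run with the decomposed (composite) gelu backward is compared against the run with the original gelu_grad kernel, and the outputs and gradients must agree within rtol=1e-6.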