未验证 提交 1b8619c7 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim][NewIR] Add test case prim custom vjp in NewIR (#57030)

* test prim custom vjp in New IR

* polish gelu_grad
上级 1642e84b
...@@ -63,13 +63,14 @@ VJPS = [ ...@@ -63,13 +63,14 @@ VJPS = [
'transpose_grad', 'transpose_grad',
'dropout_grad', 'dropout_grad',
] ]
VJP_COMPS = ['divide_grad', 'sum_grad'] VJP_COMPS = ['divide_grad', 'sum_grad', 'gelu_grad']
BACKENDS = [ BACKENDS = [
'add_n', 'add_n',
'mean', 'mean',
'sum', 'sum',
'divide', 'divide',
'full', 'full',
'tanh',
'tanh_grad', 'tanh_grad',
'mean_grad', 'mean_grad',
'concat', 'concat',
......
...@@ -125,6 +125,44 @@ void sum_grad(const Tensor& x, ...@@ -125,6 +125,44 @@ void sum_grad(const Tensor& x,
set_output<T>(x_grad_tmp, x_grad); set_output<T>(x_grad_tmp, x_grad);
} }
template <typename T>
void gelu_grad(const Tensor& x,
const Tensor& out_grad,
bool approximate,
Tensor* x_grad) {
if (!x_grad) return;
// Promote to fp32 when the input type is fp16 for keeping consistent with
// phi kernel
// Scale only support fp32 attr in static graph mode, use elementwise_xx
// when precision is over fp32.
if (approximate) {
auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
auto kKappa = 0.044715;
auto x_sq = x * x;
auto x_cube = x_sq * x;
auto inner = kBeta * (x + kKappa * x_cube);
auto tanh_inner = tanh<T>(inner);
auto left = scale<T>(x, 0.5);
auto right = scale<T>(tanh_inner, 1., 1.);
auto left_derivative = scale<T>(right, 0.5);
auto tanh_derivative = scale<T>(tanh_inner * tanh_inner, -1., 1.);
auto inner_derivative = kBeta * (scale<T>(3 * kKappa * x_sq, 1., 1.));
auto right_derivative = left * tanh_derivative * inner_derivative;
set_output<T>(out_grad * (left_derivative + right_derivative), x_grad);
} else {
auto kAlpha = M_SQRT1_2;
auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
auto cdf = scale<T>(scale<T>(erf<T>(kAlpha * x), 1., 1.), 0.5);
auto pdf = kBeta * exp<T>(scale<T>(x * x, -0.5));
set_output<T>(out_grad * (cdf + x * pdf), x_grad);
}
}
} // namespace details } // namespace details
} // namespace primitive } // namespace primitive
} // namespace paddle } // namespace paddle
set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet) set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet
test_prim_custom_vjp)
foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES}) foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import _ir_ops, nn
from paddle.autograd.ir_backward import grad
from paddle.decomposition import decompose
from paddle.framework import core
paddle.enable_static()
class SimpNet(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x, linear1_weight, linear2_weight):
x2 = _ir_ops.matmul(x, linear1_weight, False, False)
x3 = _ir_ops.gelu(x2, False)
res = _ir_ops.matmul(x3, linear2_weight, False, False)
return res
class TestPrimMode(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
self.shape_x = [2, 1024, 1024]
self.shape_y = [2, 1024, 1024]
self.shape_l1_w = [2, 1024, 4096]
self.shape_l2_w = [2, 4096, 1024]
self.x = np.random.random(self.shape_x).astype("float32")
self.y = np.random.random(self.shape_y).astype("float32")
self.l1_w = np.random.random(self.shape_l1_w).astype("float32")
self.l2_w = np.random.random(self.shape_l2_w).astype("float32")
def base_net(self, flag=None):
if flag == "backward":
core._set_prim_backward_enabled(True)
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
net = SimpNet()
x = paddle.static.data('x', self.shape_x, dtype='float32')
y = paddle.static.data('y', self.shape_y, dtype='float32')
x.stop_gradient = False
y.stop_gradient = False
l1_w = paddle.static.data('l1_w', self.shape_l1_w, dtype='float32')
l2_w = paddle.static.data('l2_w', self.shape_l2_w, dtype='float32')
divide_out = paddle.divide(x, y)
res = net(divide_out, l1_w, l2_w)
[res2] = decompose(
main_program,
[res],
)
gradients = grad(res2, (x, y))
if flag == "backward":
whole_ops_before = [
op.name() for op in main_program.block().ops
]
assert (
"pd.gelu" in whole_ops_before
and "pd.gelu_grad" not in whole_ops_before
)
core._set_prim_forward_enabled(True)
[res2] = decompose(main_program, [res2], whitelist={"pd.gelu"})
whole_ops_after = [op.name() for op in main_program.block().ops]
assert "pd.gelu" not in whole_ops_after
core._set_prim_forward_enabled(False)
exe = paddle.static.Executor()
outs = exe.run(
feed={
'x': self.x,
'y': self.y,
'l1_w': self.l1_w,
'l2_w': self.l2_w,
},
fetch_list=[res2, gradients[0], gradients[1]],
)
if flag == "backward":
core._set_prim_backward_enabled(False)
return outs
def test_prim_custom_vjp(self):
res_ref = self.base_net()
res = self.base_net("backward")
for ref, actual in zip(res_ref, res):
np.testing.assert_allclose(ref, actual, rtol=1e-6)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册