From 4ed6f3bc01902e11dc7794e807b7c0401b9a6d7e Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Thu, 1 Sep 2022 15:40:38 +0800 Subject: [PATCH] add gelu and erf primitive operators for new autograd (#45338) * add erf_p primitive operators * add gelu orig2prim rule --- .../fluid/operators/prim_ops/CMakeLists.txt | 6 +- paddle/fluid/operators/prim_ops/erf_p_op.cc | 78 +++++++++++++++++++ .../fluid/operators/prim_ops/prim_op_test.cc | 20 +++++ .../autograd/test_jvp_and_transpose.py | 36 +++++++++ .../unittests/autograd/test_orig2prim.py | 65 ++++++++++++++++ .../unittests/autograd/test_prim2orig.py | 20 +++++ .../tests/unittests/autograd/test_primapi.py | 55 +++++++++---- .../tests/unittests/autograd/test_primops.py | 1 + python/paddle/incubate/autograd/primops.py | 5 ++ python/paddle/incubate/autograd/primrules.py | 48 +++++++++++- 10 files changed, 313 insertions(+), 21 deletions(-) create mode 100644 paddle/fluid/operators/prim_ops/erf_p_op.cc diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt index 1651dae2d0f..1f63d5d1721 100644 --- a/paddle/fluid/operators/prim_ops/CMakeLists.txt +++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt @@ -22,13 +22,17 @@ set(PRIM_OP_SRCS div_p_op.cc sqrt_p_op.cc tanh_p_op.cc + sin_p_op.cc + cos_p_op.cc + exp_p_op.cc matmul_p_op.cc fill_constant_p_op.cc log_p_op.cc select_p_op.cc eq_p_op.cc pow_p_op.cc - max_p_op.cc) + max_p_op.cc + erf_p_op.cc) cc_test( prim_op_test diff --git a/paddle/fluid/operators/prim_ops/erf_p_op.cc b/paddle/fluid/operators/prim_ops/erf_p_op.cc new file mode 100644 index 00000000000..21bfdf1fd4d --- /dev/null +++ b/paddle/fluid/operators/prim_ops/erf_p_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace framework {
+class InferShapeContext;
+class VarDesc;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace operators {
+class ErfPrimOp : public framework::OperatorBase {
+ public:
+  ErfPrimOp(const std::string &type,
+            const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Prim operator erf_p should not be executed directly"));
+  }
+};
+
+class ErfPrimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of erf_p op.");
+    AddOutput("Y", "(Tensor), The output tensor of erf_p op.");
+    AddComment(R"DOC(Autograd primitive erf_p operator.)DOC");
+  }
+};
+
+class ErfPrimOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
+    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
+    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
+    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
+  }
+};
+
+class ErfPrimOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = Input(ctx, "X")[0];
+    auto y_name = Output(ctx, "Y")[0];
+    SetType(ctx, y_name, GetType(ctx, x_name));
+    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OPERATOR(erf_p,
+                  paddle::operators::ErfPrimOp,
+                  paddle::operators::ErfPrimOpMaker,
+                  paddle::operators::ErfPrimOpShapeInference,
+                  paddle::operators::ErfPrimOpVarTypeInference);
diff --git a/paddle/fluid/operators/prim_ops/prim_op_test.cc b/paddle/fluid/operators/prim_ops/prim_op_test.cc
index f3a74138abb..44872f9060b 100644
--- a/paddle/fluid/operators/prim_ops/prim_op_test.cc
+++ b/paddle/fluid/operators/prim_ops/prim_op_test.cc
@@ -39,6 +39,7 @@ USE_OP_ITSELF(select_p);
 USE_OP_ITSELF(eq_p);
 USE_OP_ITSELF(pow_p);
 USE_OP_ITSELF(max_p);
+USE_OP_ITSELF(erf_p);
 
 namespace paddle {
 namespace framework {
@@ -710,5 +711,24 @@ TEST(PrimOp, max_p) {
   ASSERT_EQ(shapes[2], 4L);
 }
 
+TEST(PrimOp, erf_p) {
+  ProgramDesc program;
+  auto *block = program.MutableBlock(0);
+  std::vector<int64_t> shape{3, 4, 5};
+
+  std::string x0 = "x0";
+  std::string x1 = "x1";
+
+  NewVar(block, x0, shape);
+  AppendOp(block, "erf_p", {{"X", {x0}}}, {{"Y", {x1}}}, {});
+  ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
+  ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
+  auto shapes = block->Var("x1")->GetShape();
+  ASSERT_EQ(shapes.size(), 3UL);
+  ASSERT_EQ(shapes[0], 3L);
+  ASSERT_EQ(shapes[1], 4L);
+  ASSERT_EQ(shapes[2], 5L);
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
index 2c23da54970..51104223f95 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
+++ 
b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py @@ -364,6 +364,42 @@ class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose): ] +class TestErfPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'erf_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = { + 'X': X, + } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + self.all_ops = [ + # prim op: + 'erf_p', + # jvp op: + 'exp_p', + 'fill_constant_p', + 'fill_constant_p', + 'fill_constant_p', + 'mul_p', + 'mul_p', + 'pow_p', + 'sub_p', + # transpose op: + ] + + class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index ce9c64fbbed..c9f1aa6c41a 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -208,6 +208,26 @@ class TestExpOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestErfOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'erf' + X = paddle.static.data(name='X', shape=[3, 4], dtype='float') + + self.input = { + 'X': X, + } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['erf', 'erf_p'] + self.out_map = {0: self.output['Out']} + + class TestLogOrig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): @@ -559,5 +579,50 @@ class TestMaxOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestGeluOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'gelu' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'approximate': False} + + self.orig2prim_args = (X, ) + self.all_ops = [ + 'gelu', 'add_p', 'erf_p', 'fill_constant_p', 'fill_constant_p', + 'fill_constant_p', 'mul_p', 'mul_p', 'mul_p' + ] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'gelu' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'approximate': True} + + self.orig2prim_args = (X, ) + self.all_ops = [ + 'add_p', 'add_p', 'fill_constant_p', 'fill_constant_p', + 'fill_constant_p', 'fill_constant_p', 'fill_constant_p', 'gelu', + 'mul_p', 'mul_p', 'mul_p', 'mul_p', 'pow_p', 'tanh_p' + ] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py index 53120ce742a..4d0f1500736 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py +++ 
b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py @@ -224,6 +224,26 @@ class TestExpPPrim2Orig(TestAddPPrim2Orig): self.out_map = {self.output['Y']: 0} +class TestErfPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'erf_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = { + 'X': X, + } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, ) + self.all_ops = ['erf_p', 'erf'] + self.out_map = {self.output['Y']: 0} + + class TestLogPPrim2Orig(TestAddPPrim2Orig): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py index 5bd21e41904..7edbb6ef77f 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -16,11 +16,11 @@ import typing import unittest import numpy as np -import autograd -import autograd.numpy as np_autograd - import paddle +import autograd +import autograd.numpy as anp +import autograd.scipy as ascipy import config import utils @@ -278,6 +278,11 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) np.array([1, 2, 3]), np.array([2, 2, 2]), ), None, 'float32'), + ('erf', paddle.erf, (np.random.rand(300, 288), ), None, 'float32'), + ('gelu', paddle.nn.functional.gelu, + (np.random.rand(200, 189), ), None, 'float32'), + ('gelu_approximate', lambda x: paddle.nn.functional.gelu(x, True), + (np.random.rand(200, 189), ), None, 'float32'), )) class TestGrad(unittest.TestCase): @@ -397,25 +402,41 @@ def multiply_pd(x): multiply_ag = lambda xs: xs[0] * xs[0] * xs[0] * xs[0] * xs[0] -sin_ag = lambda xs: np_autograd.sin(xs[0]) -cos_ag = lambda xs: np_autograd.cos(xs[0]) -exp_ag = lambda xs: np_autograd.exp(xs[0]) +sin_ag = lambda xs: anp.sin(xs[0]) +cos_ag = lambda xs: anp.cos(xs[0]) +exp_ag = lambda xs: anp.exp(xs[0]) pow_ag = lambda xs: xs[0]**xs[1] -log_ag = lambda xs: np_autograd.log(xs[0]) +log_ag = lambda xs: anp.log(xs[0]) +erf_ag = lambda xs: ascipy.special.erf(xs[0]) + + +def gelu_ag(x, approximate=False): + if approximate: + sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) + cdf = 0.5 * (1.0 + anp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x**3)))) + return x * cdf + else: + return x * (ascipy.special.erf(x / np.sqrt(2)) + 1) / 2 @utils.place(config.DEVICES) @utils.parameterize( - (utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), ( - ('multiply', multiply_pd, multiply_ag, - (np.random.rand(3, 5), ), None, 'float32'), - ('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'), - ('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'), - ('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'), - ('pow', paddle.pow, pow_ag, - (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), - ('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'), - )) + (utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), + (('multiply', multiply_pd, multiply_ag, + (np.random.rand(3, 5), ), None, 'float32'), + ('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'), + ('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'), + ('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'), + ('pow', paddle.pow, pow_ag, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), + ('log', paddle.log, log_ag, 
(np.random.rand(3, 8), ), None, 'float32'), + ('erf', paddle.erf, erf_ag, (np.random.rand(100, 200), ), None, 'float32'), + ('gelu', paddle.nn.functional.gelu, lambda xs: gelu_ag(xs[0]), + (np.random.rand(10, 20, 30), ), None, 'float32'), + ('gelu_approximate', + lambda x: paddle.nn.functional.gelu(x, approximate=True), + lambda xs: gelu_ag(xs[0], approximate=True), + (np.random.rand(10, 20, 30), ), None, 'float32'))) class TestGradWithHigherOrder(unittest.TestCase): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py index f1396ce69f9..25a3d9bce23 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primops.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py @@ -41,6 +41,7 @@ paddle.enable_static() ('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'), ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'), ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'), + ('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'), ('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'), ('reshape', primops.reshape, randn(2, 3), { 'shape': (3, 2) diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py index d29647f9404..cb67d3f23e9 100644 --- a/python/paddle/incubate/autograd/primops.py +++ b/python/paddle/incubate/autograd/primops.py @@ -355,3 +355,8 @@ def pow(x, y, out=None): @REGISTER_FN('max_p', 'X', 'Y', 'Z') def max(x, y, out=None): return _simple_binop(LayerHelper('max_p', **locals())) + + +@REGISTER_FN('erf_p', 'X', 'Y') +def erf(x, out=None): + return _simple_unop(LayerHelper('erf_p', **locals())) diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index bfcfcfb9a4f..9e14c863330 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import typing +import math import paddle @@ -19,7 +20,7 @@ from . 
import primops from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather, matmul, mul, neg, reduce, reshape, scatter_add, set_value, sin, slice_assign, slice_select, split, sqrt, sub, tanh, - transpose, log, select, eq, max) + transpose, log, select, eq, max, erf) from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_TRANSPOSE, lookup_fn, lookup_jvp, lookup_orig2prim, lookup_prim2orig, lookup_transpose, @@ -171,6 +172,11 @@ def exp_orig2prim(op, x): return exp(x) +@REGISTER_ORIG2PRIM('erf') +def erf_orig2prim(op, x): + return erf(x) + + @REGISTER_ORIG2PRIM('log') def log_orig2prim(op, x): return log(x) @@ -321,13 +327,34 @@ def elementwise_pow_orig2prim(op, x, y): def elementwise_max_orig2prim(op, x, y): if x.shape != y.shape: y = broadcast(y, shape=x.shape) - return primops.max(x, y) -## Register prim2orig lower rules +@REGISTER_ORIG2PRIM('gelu') +def gelu_orig2prim(op, x): + if op.attr('approximate'): + cdf = mul( + fill_const(0.5, x.shape, x.dtype), + add( + fill_const(1.0, x.shape, x.dtype), + tanh( + mul( + fill_const(math.sqrt(2 / math.pi), x.shape, x.dtype), + add( + x, + mul( + fill_const(0.044715, x.shape, x.dtype), + primops.pow(x, fill_const(3., x.shape, + x.dtype)))))))) + return mul(x, cdf) + else: + return mul( + mul(fill_const(0.5, x.shape, x.dtype), x), + add(fill_const(1.0, x.shape, x.dtype), + erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype))))) +## Register prim2orig lower rules @REGISTER_PRIM2ORIG('add_p') def add_prim2orig(op, x, y): return paddle.add(x, y) @@ -373,6 +400,11 @@ def exp_prim2orig(op, x): return paddle.exp(x) +@REGISTER_PRIM2ORIG('erf_p') +def erf_prim2orig(op, x): + return paddle.erf(x) + + @REGISTER_PRIM2ORIG('log_p') def log_prim2orig(op, x): return paddle.log(x) @@ -574,6 +606,16 @@ def exp_jvp(op, x_dot): return mul(x_dot, y) +@REGISTER_JVP('erf_p') +def erf_jvp(op, x_dot): + if x_dot is None: + return None + x, = op_position_inputs(op) + return mul( + fill_const(2. / math.sqrt(math.pi), x.shape, x.dtype), + mul(x_dot, exp(neg(primops.pow(x, fill_const(2., x.shape, x.dtype)))))) + + @REGISTER_JVP('log_p') def log_jvp(op, x_dot): if x_dot is None: -- GitLab
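
The rules added to python/paddle/incubate/autograd/primrules.py rely on two identities: the erf_p JVP uses d/dx erf(x) = 2/sqrt(pi) * exp(-x^2), and gelu_orig2prim lowers gelu either to 0.5 * x * (1 + erf(x / sqrt(2))) or, with approximate=True, to the tanh form 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). Below is a minimal NumPy-only sketch that numerically checks those identities; it is illustrative, sits outside the patch itself, and assumes nothing beyond the Python standard library and NumPy (no Paddle APIs).

# Minimal numerical sanity check (illustrative sketch, not part of the patch)
# of the identities used by the new erf_p JVP and gelu orig2prim rules.
import math

import numpy as np

erf = np.vectorize(math.erf)  # apply scalar math.erf elementwise

x = np.linspace(-3.0, 3.0, 201)

# erf_p JVP rule: d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
eps = 1e-6
numeric = (erf(x + eps) - erf(x - eps)) / (2 * eps)
analytic = 2.0 / math.sqrt(math.pi) * np.exp(-x * x)
assert np.allclose(numeric, analytic, atol=1e-6)

# gelu orig2prim, approximate=False: 0.5 * x * (1 + erf(x / sqrt(2)))
gelu_exact = 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

# gelu orig2prim, approximate=True: tanh-based approximation
gelu_tanh = 0.5 * x * (1.0 + np.tanh(
    math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

# The exact and tanh lowerings agree to roughly 1e-3 on this range.
assert np.allclose(gelu_exact, gelu_tanh, atol=1e-3)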