Unverified commit 4ed6f3bc authored by Xiaoxu Chen, committed by GitHub

add gelu and erf primitive operators for new autograd (#45338)

* add erf_p primitive operators

* add gelu orig2prim rule
Parent 1a0ef45e
......@@ -22,13 +22,17 @@ set(PRIM_OP_SRCS
div_p_op.cc
sqrt_p_op.cc
tanh_p_op.cc
sin_p_op.cc
cos_p_op.cc
exp_p_op.cc
matmul_p_op.cc
fill_constant_p_op.cc
log_p_op.cc
select_p_op.cc
eq_p_op.cc
pow_p_op.cc
max_p_op.cc)
max_p_op.cc
erf_p_op.cc)
cc_test(
prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
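// erf_p exists only as a node in the transformed static program; it is lowered
// to an original operator before execution, so running it directly is an error.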
class ErfPrimOp : public framework::OperatorBase {
public:
ErfPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator erf_p should not be excuted directly"));
}
};
class ErfPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of erf_p op.");
AddOutput("Y", "(Tensor), The output tensor of erf_p op.");
AddComment(R"DOC(Autograd primitive erf_p operator.)DOC");
}
};
class ErfPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
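// erf_p is elementwise: the output Y keeps the shape of the input X.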
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class ErfPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(erf_p,
paddle::operators::ErfPrimOp,
paddle::operators::ErfPrimOpMaker,
paddle::operators::ErfPrimOpShapeInference,
paddle::operators::ErfPrimOpVarTypeInference);
......@@ -39,6 +39,7 @@ USE_OP_ITSELF(select_p);
USE_OP_ITSELF(eq_p);
USE_OP_ITSELF(pow_p);
USE_OP_ITSELF(max_p);
USE_OP_ITSELF(erf_p);
namespace paddle {
namespace framework {
......@@ -710,5 +711,24 @@ TEST(PrimOp, max_p) {
ASSERT_EQ(shapes[2], 4L);
}
TEST(PrimOp, erf_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "erf_p", {{"X", {x0}}}, {{"Y", {x1}}}, {});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
} // namespace framework
} // namespace paddle
......@@ -364,6 +364,42 @@ class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestErfPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'erf_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'erf_p',
# jvp op:
'exp_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'mul_p',
'mul_p',
'pow_p',
'sub_p',
# transpose op:
]
class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......
......@@ -208,6 +208,26 @@ class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestErfOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'erf'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['erf', 'erf_p']
self.out_map = {0: self.output['Out']}
class TestLogOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......@@ -559,5 +579,50 @@ class TestMaxOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestGeluOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'gelu'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'approximate': False}
self.orig2prim_args = (X, )
self.all_ops = [
'gelu', 'add_p', 'erf_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'mul_p', 'mul_p', 'mul_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'gelu'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'approximate': True}
self.orig2prim_args = (X, )
self.all_ops = [
'add_p', 'add_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'fill_constant_p', 'fill_constant_p', 'gelu',
'mul_p', 'mul_p', 'mul_p', 'mul_p', 'pow_p', 'tanh_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
......@@ -224,6 +224,26 @@ class TestExpPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Y']: 0}
class TestErfPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'erf_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['erf_p', 'erf']
self.out_map = {self.output['Y']: 0}
class TestLogPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
......
......@@ -16,11 +16,11 @@ import typing
import unittest
import numpy as np
import autograd
import autograd.numpy as np_autograd
import paddle
import autograd
import autograd.numpy as anp
import autograd.scipy as ascipy
import config
import utils
......@@ -278,6 +278,11 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
np.array([1, 2, 3]),
np.array([2, 2, 2]),
), None, 'float32'),
('erf', paddle.erf, (np.random.rand(300, 288), ), None, 'float32'),
('gelu', paddle.nn.functional.gelu,
(np.random.rand(200, 189), ), None, 'float32'),
('gelu_approximate', lambda x: paddle.nn.functional.gelu(x, True),
(np.random.rand(200, 189), ), None, 'float32'),
))
class TestGrad(unittest.TestCase):
......@@ -397,25 +402,41 @@ def multiply_pd(x):
multiply_ag = lambda xs: xs[0] * xs[0] * xs[0] * xs[0] * xs[0]
sin_ag = lambda xs: np_autograd.sin(xs[0])
cos_ag = lambda xs: np_autograd.cos(xs[0])
exp_ag = lambda xs: np_autograd.exp(xs[0])
sin_ag = lambda xs: anp.sin(xs[0])
cos_ag = lambda xs: anp.cos(xs[0])
exp_ag = lambda xs: anp.exp(xs[0])
pow_ag = lambda xs: xs[0]**xs[1]
log_ag = lambda xs: np_autograd.log(xs[0])
log_ag = lambda xs: anp.log(xs[0])
erf_ag = lambda xs: ascipy.special.erf(xs[0])
def gelu_ag(x, approximate=False):
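# Autograd reference for GELU: the tanh approximation when approximate=True,
# otherwise the exact form x / 2 * (1 + erf(x / sqrt(2))).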
if approximate:
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
cdf = 0.5 * (1.0 + anp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x**3))))
return x * cdf
else:
return x * (ascipy.special.erf(x / np.sqrt(2)) + 1) / 2
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), (
('multiply', multiply_pd, multiply_ag,
(np.random.rand(3, 5), ), None, 'float32'),
('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'),
('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'),
('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'),
('pow', paddle.pow, pow_ag,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'),
))
(utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'),
(('multiply', multiply_pd, multiply_ag,
(np.random.rand(3, 5), ), None, 'float32'),
('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'),
('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'),
('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'),
('pow', paddle.pow, pow_ag,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'),
('erf', paddle.erf, erf_ag, (np.random.rand(100, 200), ), None, 'float32'),
('gelu', paddle.nn.functional.gelu, lambda xs: gelu_ag(xs[0]),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('gelu_approximate',
lambda x: paddle.nn.functional.gelu(x, approximate=True),
lambda xs: gelu_ag(xs[0], approximate=True),
(np.random.rand(10, 20, 30), ), None, 'float32')))
class TestGradWithHigherOrder(unittest.TestCase):
def setUp(self):
......
......@@ -41,6 +41,7 @@ paddle.enable_static()
('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'),
('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
('reshape', primops.reshape, randn(2, 3), {
'shape': (3, 2)
......
......@@ -355,3 +355,8 @@ def pow(x, y, out=None):
@REGISTER_FN('max_p', 'X', 'Y', 'Z')
def max(x, y, out=None):
return _simple_binop(LayerHelper('max_p', **locals()))
@REGISTER_FN('erf_p', 'X', 'Y')
def erf(x, out=None):
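# erf_p is a simple unary primitive: one input X, one output Y, no attributes.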
return _simple_unop(LayerHelper('erf_p', **locals()))
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import math
import paddle
......@@ -19,7 +20,7 @@ from . import primops
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
matmul, mul, neg, reduce, reshape, scatter_add, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose, log, select, eq, max)
transpose, log, select, eq, max, erf)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
......@@ -171,6 +172,11 @@ def exp_orig2prim(op, x):
return exp(x)
@REGISTER_ORIG2PRIM('erf')
def erf_orig2prim(op, x):
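# The original erf op lowers 1:1 to the erf_p primitive.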
return erf(x)
@REGISTER_ORIG2PRIM('log')
def log_orig2prim(op, x):
return log(x)
......@@ -321,13 +327,34 @@ def elementwise_pow_orig2prim(op, x, y):
def elementwise_max_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return primops.max(x, y)
## Register prim2orig lower rules
@REGISTER_ORIG2PRIM('gelu')
def gelu_orig2prim(op, x):
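# Lower gelu to primitives:
#   approximate=True:  0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
#   approximate=False: 0.5 * x * (1 + erf(x / sqrt(2)))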
if op.attr('approximate'):
cdf = mul(
fill_const(0.5, x.shape, x.dtype),
add(
fill_const(1.0, x.shape, x.dtype),
tanh(
mul(
fill_const(math.sqrt(2 / math.pi), x.shape, x.dtype),
add(
x,
mul(
fill_const(0.044715, x.shape, x.dtype),
primops.pow(x, fill_const(3., x.shape,
x.dtype))))))))
return mul(x, cdf)
else:
return mul(
mul(fill_const(0.5, x.shape, x.dtype), x),
add(fill_const(1.0, x.shape, x.dtype),
erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype)))))
## Register prim2orig lower rules
@REGISTER_PRIM2ORIG('add_p')
def add_prim2orig(op, x, y):
return paddle.add(x, y)
......@@ -373,6 +400,11 @@ def exp_prim2orig(op, x):
return paddle.exp(x)
@REGISTER_PRIM2ORIG('erf_p')
def erf_prim2orig(op, x):
return paddle.erf(x)
@REGISTER_PRIM2ORIG('log_p')
def log_prim2orig(op, x):
return paddle.log(x)
......@@ -574,6 +606,16 @@ def exp_jvp(op, x_dot):
return mul(x_dot, y)
@REGISTER_JVP('erf_p')
def erf_jvp(op, x_dot):
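# d/dx erf(x) = 2/sqrt(pi) * exp(-x^2), so y_dot = 2/sqrt(pi) * exp(-x^2) * x_dot.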
if x_dot is None:
return None
x, = op_position_inputs(op)
return mul(
fill_const(2. / math.sqrt(math.pi), x.shape, x.dtype),
mul(x_dot, exp(neg(primops.pow(x, fill_const(2., x.shape, x.dtype))))))
@REGISTER_JVP('log_p')
def log_jvp(op, x_dot):
if x_dot is None:
......