Unverified · Commit 01ab8a06 authored by: F Feiyu Chan, committed by: GitHub

add approximation for gelu, test=develop (#22961)

Add approximation for gelu; the default value is False (only the Eigen kernel is added; the code for computing gelu with MKLDNN is removed temporarily).
Parent eec10aab
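For orientation before the diff itself, here is a minimal NumPy sketch (not part of the patch) of the two formulas this commit wires up: the exact erf-based GELU and the tanh approximation selected by the new `approximate` attribute. The helper name `gelu_ref` is hypothetical.

```python
import numpy as np
from scipy.special import erf

def gelu_ref(x, approximate=False):
    """Hypothetical reference helper mirroring the formulas in this patch."""
    if approximate:
        # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + np.tanh(
            np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))))
    # exact form: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

x = np.linspace(-3.0, 3.0, 7)
# The two variants agree closely but not exactly.
print(np.max(np.abs(gelu_ref(x, True) - gelu_ref(x, False))))
```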
......@@ -185,13 +185,6 @@ $out = \max(x, 0)$
)DOC";
UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.
$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$
)DOC";
UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.
......@@ -635,7 +628,6 @@ REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
......
......@@ -304,90 +304,6 @@ struct ReluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
template <typename T>
struct GeluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
// Because the execution or device context cannot be delivered here, keep the
// macro for NVCC.
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
std::memset(out_data, 0, n * sizeof(T));
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
}
};
// gelu_grad(x) = dout * (0.5 * (1 + erf(x / sqrt(2))) + 0.5 * 2 / sqrt(pi) /
// sqrt(2) * x * exp (-0.5 * x^2))
template <typename T>
struct GeluGradFunctor : BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto x_data = x.data();
auto dx_data = dx.data();
auto dout_data = dout.data();
int n = std::min(x.size(), dx.size());
auto first = static_cast<T*>(std::malloc(n * sizeof(T)));
std::memset(first, 0, n * sizeof(T));
auto second = static_cast<T*>(std::malloc(n * sizeof(T)));
std::memset(second, 0, n * sizeof(T));
// first = (0.5 * (1 + erf(x / sqrt(2))))
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, first, 1);
math::CBlas<T>::VMERF(n, first, first, VML_LA);
for (int i = 0; i < n; i++) {
first[i] += static_cast<T>(1);
}
math::CBlas<T>::SCAL(n, static_cast<T>(0.5), first, 1);
// second = (0.5 * 2/sqrt(pi) * 1/sqrt(2) * x * exp(-0.5 * x^2))
math::CBlas<T>::VSQUARE(n, x_data, second);
math::CBlas<T>::SCAL(n, -static_cast<T>(0.5), second, 1);
math::CBlas<T>::VEXP(n, second, second);
math::CBlas<T>::VMUL(n, x_data, second, second);
math::CBlas<T>::SCAL(n, static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2),
second, 1);
// dx = dout * (first + second);
math::CBlas<T>::VADD(n, first, second, first);
math::CBlas<T>::VMUL(n, dout_data, first, dx_data);
std::free(first);
std::free(second);
#else
auto first = static_cast<T>(0.5) *
(static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
(-static_cast<T>(0.5) * x.square()).exp();
dx.device(d) = dout * (first + second);
#endif
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
......@@ -1727,7 +1643,6 @@ class PowGradKernel
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
......
......@@ -101,6 +101,7 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("erf_grad");
grad_op->SetInput("X", this->Input("X"));
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
class GeluOp : public framework::OperatorWithKernel {
public:
GeluOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of GeluOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(%s) of GeluOp should not be null.", "Out"));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class GeluGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(%s) of GeluGradOp should not be null.", "DOut"));
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of GeluGradOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(%s) of GeluGradOp should not be null.", "DX"));
auto x_grad_name = framework::GradVarName("X");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class GeluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Gelu operator");
AddOutput("Out", "Output of Gelu operator");
AddAttr<bool>("approximate",
"(bool, default false) use approximation of gelu")
.SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("use_cudnn",
"(bool, default false) Only used in cudnn kernel, need "
"install cudnn")
.SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
Gelu Activation Operator.
For more details, please refer to [Gaussian Error Linear Units](https://arxiv.org/pdf/1606.08415.pdf).
when the attribute "approximate" is true
$out = \\frac{1}{2}x(1+tanh(\\sqrt{\\frac{2}{\\pi}}(x+0.044715x^{3})))$
otherwise
$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$
)DOC");
}
};
template <typename T>
class GeluGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("gelu_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gelu, ops::GeluOp, ops::GeluOpMaker,
ops::GeluGradOpMaker<paddle::framework::OpDesc>,
ops::GeluGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gelu_grad, ops::GeluGradOp);
REGISTER_OP_CPU_KERNEL(
gelu, ops::GeluKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
gelu_grad, ops::GeluGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
gelu, ops::GeluKernel<paddle::platform::CUDADeviceContext, float>,
ops::GeluKernel<paddle::platform::CUDADeviceContext, double>,
ops::GeluKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
gelu_grad, ops::GeluGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GeluGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::GeluGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <algorithm>
#include <cmath>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
struct GeluFunctor {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out, bool approximate) const {
if (approximate) {
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
auto temp = (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
(x + static_cast<T>(0.044715) * x.cube()))
.tanh();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
} else {
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
}
}
};
template <typename T>
struct GeluGradFunctor {
template <typename Device, typename X, typename dOut, typename dX>
void operator()(Device d, X x, dOut dout, dX dx, bool approximate) const {
if (approximate) {
const T kAlpha = static_cast<T>(M_2_SQRTPI * M_SQRT1_2);
const T kBeta = kAlpha * static_cast<T>(0.044715) * static_cast<T>(3);
const auto y =
(kAlpha * ((static_cast<T>(0.044715) * x.cube()) + x)).tanh();
dx.device(d) = static_cast<T>(0.5) * dout *
(static_cast<T>(1) + y +
(x - x * y.square()) * (kAlpha + kBeta * x.square()));
} else {
// gelu_grad(x) = dout * 0.5 * (1 + erf(x / sqrt(2)) +
//                x * sqrt(2 / pi) * exp(-x^2 / 2))
auto first =
static_cast<T>(0.5) *
(static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
(-static_cast<T>(0.5) * x.square()).exp();
dx.device(d) = dout * (first + second);
}
}
};
template <typename DeviceContext, typename T>
class GeluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
auto approximate = context.Attr<bool>("approximate");
out->mutable_data<T>(in->place());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
GeluFunctor<T> functor;
functor(place, eigen_in, eigen_out, approximate);
}
};
template <typename DeviceContext, typename T>
class GeluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto approximate = context.Attr<bool>("approximate");
dx->mutable_data<T>(dout->place());
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_dout = framework::EigenVector<T>::Flatten(*dout);
auto eigen_dx = framework::EigenVector<T>::Flatten(*dx);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
GeluGradFunctor<T> functor;
functor(place, eigen_x, eigen_dout, eigen_dx, approximate);
}
};
} // namespace operators
} // namespace paddle
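The closed-form gradient used by GeluGradFunctor for the tanh approximation is easy to mistype, so here is a small finite-difference sketch (not part of the patch; plain NumPy with hypothetical helper names) that checks the formula above numerically.

```python
import numpy as np

def gelu_approx(x):
    # Forward pass of the tanh approximation (GeluFunctor, approximate=True).
    return 0.5 * x * (1.0 + np.tanh(
        np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def gelu_approx_grad(x):
    # Closed-form gradient (GeluGradFunctor, approximate=True), with dout == 1.
    k_alpha = np.sqrt(2.0 / np.pi)
    k_beta = k_alpha * 0.044715 * 3.0
    y = np.tanh(k_alpha * (x + 0.044715 * x ** 3))
    return 0.5 * (1.0 + y + (x - x * y ** 2) * (k_alpha + k_beta * x ** 2))

x = np.linspace(-3.0, 3.0, 13)
eps = 1e-6
numeric = (gelu_approx(x + eps) - gelu_approx(x - eps)) / (2.0 * eps)
# The analytic and numerical gradients should agree to roughly 1e-8 or better.
print(np.max(np.abs(numeric - gelu_approx_grad(x))))
```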
......@@ -245,7 +245,7 @@ __all__ += ['gelu']
_gelu_ = generate_layer_fn('gelu')
def gelu(x):
def gelu(x, approximate=False):
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
......@@ -259,6 +259,11 @@ gelu.__doc__ = """
For more details, see [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
Equation:
if approximate is True
.. math::
out = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3})))
else
.. math::
out = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}}))
......
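A quick dygraph usage sketch of the updated Python API (assuming a build that contains this patch); it mirrors the unit test further below and is not part of the diff.

```python
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32)
with dg.guard(fluid.CPUPlace()):
    x_var = dg.to_variable(x)
    y_exact = fluid.layers.gelu(x_var)        # approximate defaults to False
    y_approx = fluid.layers.gelu(x_var, True) # tanh approximation
    print(y_exact.numpy().shape, y_approx.numpy().shape)
```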
......@@ -411,16 +411,44 @@ class TestLeakyRelu(TestActivation):
self.check_grad(['X'], 'Out')
class TestGelu(TestActivation):
def gelu(x, approximate):
if approximate:
y_ref = 0.5 * x * (1.0 + np.tanh(
np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
else:
y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2)))
return y_ref.astype(x.dtype)
class TestGeluApproximate(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.init_dtype()
approximate = True
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, approximate)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"approximate": approximate}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestGelu(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.init_dtype()
approximate = False
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
out = gelu(x, approximate)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"approximate": approximate}
def test_check_grad(self):
if self.dtype == np.float16:
......
......@@ -21,33 +21,43 @@ import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
def gelu(x, approximate):
if approximate:
y_ref = 0.5 * x * (1.0 + np.tanh(
np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
else:
y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2)))
return y_ref.astype(x.dtype)
class TestGeluOp(unittest.TestCase):
def _test_case1_cpu(self):
def _test_case1_cpu(self, approximate):
x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32)
y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2)))
y_ref = gelu(x, approximate)
place = fluid.CPUPlace()
with dg.guard(place) as g:
x_var = dg.to_variable(x)
y_var = fluid.layers.gelu(x_var)
y_var = fluid.layers.gelu(x_var, approximate)
y_test = y_var.numpy()
self.assertTrue(np.allclose(y_ref, y_test, rtol=1e-05, atol=1e-08))
def _test_case1_gpu(self):
def _test_case1_gpu(self, approximate):
x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32)
y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2)))
y_ref = gelu(x, approximate)
place = fluid.CUDAPlace(0)
with dg.guard(place) as g:
x_var = dg.to_variable(x)
y_var = fluid.layers.gelu(x_var)
y_var = fluid.layers.gelu(x_var, approximate)
y_test = y_var.numpy()
self.assertTrue(np.allclose(y_ref, y_test, rtol=1e-05, atol=1e-08))
def test_cases(self):
self._test_case1_cpu()
if fluid.is_compiled_with_cuda():
self._test_case1_gpu()
for approximate in [True, False]:
self._test_case1_cpu(approximate)
if fluid.is_compiled_with_cuda():
self._test_case1_gpu(approximate)
if __name__ == '__main__':
......