未验证 提交 567e2fc8 编写于 作者: S Sławomir Siwek 提交者: GitHub

[PHI] Migrate gelu kernels (#45596)

* gaussian random

* mkldnn to onednn renaming

* fix merge conflicts

* remove fluid code

* onednn renaming

* gelu fwd

* sort activations

* gelu gradient

* remove unused macros

* merge conflicts

* fix merge conflicts

* remove extra contraint from gelu op
上级 5679cdff
......@@ -510,7 +510,7 @@ function(op_library TARGET)
if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(gelu, MKLDNN);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
......
......@@ -38,7 +38,7 @@ USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(gelu);
USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
PD_DECLARE_KERNEL(gelu, OneDNN, ALL_LAYOUT);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
......
......@@ -77,8 +77,7 @@ class GeluGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
......
......@@ -52,42 +52,6 @@ class MKLDNNActivationGradKernel
}
};
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL eletwise_forward must use CPUPlace"));
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
auto *out = ctx.Output<Tensor>("Out");
bool is_inplaced = x->IsSharedBufferWith(*out);
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x);
auto src_memory_p = handler.AcquireSrcMemory(x);
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
if (is_inplaced) {
dst_memory_p = src_memory_p;
out->mutable_data<T>(ctx.GetPlace());
} else {
dst_memory_p = handler.AcquireDstMemory(out);
}
auto activation_p = handler.AcquireForwardPrimitive();
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_p->execute(
astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
......@@ -116,34 +80,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
dx->set_mem_desc(diff_src_memory_p->get_desc());
}
template <typename T>
void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *out = ctx.Input<Tensor>("Out");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), out, dout);
auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_DST, *dst_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_mem_desc(diff_src_memory_p->get_desc());
}
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
......@@ -151,30 +87,6 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
const bool approximate = ctx.Attr<bool>("approximate");
if (approximate) {
eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
} else {
eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
}
}
};
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
const bool approximate = ctx.Attr<bool>("approximate");
if (approximate) {
eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
} else {
eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
}
}
};
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
......@@ -209,6 +121,4 @@ namespace ops = paddle::operators;
ops::grad_functor<paddle::platform::bfloat16>>);
REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor);
REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNFunctor);
REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNGradFunctor);
REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor);
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/kernels/gelu_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
......@@ -23,16 +24,6 @@
namespace phi {
#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, x, dout, 0, 0, dx); \
}
#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
name, functor_class, attr) \
template <typename T, typename Context> \
......@@ -55,18 +46,6 @@ namespace phi {
functor(dev_ctx, out, dout, 0, 0, dx); \
}
#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
name, functor_class, attr) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
float attr, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, out, dout, attr, 0, dx); \
}
template <typename T>
void eltwise_grad(const OneDNNContext& dev_ctx,
const DenseTensor& x,
......@@ -158,12 +137,14 @@ using AbsOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
template <typename T>
using ReluOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
template <typename T>
using SwishOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
template <typename T>
using HardSwishOneDNNGradFunctor =
......@@ -174,14 +155,21 @@ using MishOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
using GeluTanhOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;
template <typename T>
using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
using GeluErfOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_erf>;
template <typename T>
using ReluOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
template <typename T>
using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
......@@ -189,22 +177,21 @@ using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
template <typename T>
using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
using SwishOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
SigmoidOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Abs, AbsOneDNNGradFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluOneDNNGradFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
SigmoidOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
ReluOneDNNGradFunctor,
......@@ -215,17 +202,6 @@ DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
SwishOneDNNGradFunctor,
beta);
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float threshold,
float scale,
float offset,
DenseTensor* dx) {
HardSwishOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, dout, threshold, 0, dx);
}
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
......@@ -238,6 +214,33 @@ void EluGradKernel(const Context& dev_ctx,
functor(dev_ctx, out, dout, alpha, 0, dx);
}
template <typename T, typename Context>
void GeluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
bool approximate,
DenseTensor* x_grad) {
if (approximate) {
GeluTanhOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, out_grad, 0, 0, x_grad);
} else {
GeluErfOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, out_grad, 0, 0, x_grad);
}
}
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float threshold,
float scale,
float offset,
DenseTensor* dx) {
HardSwishOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, dout, threshold, 0, dx);
}
} // namespace phi
PD_REGISTER_KERNEL(relu_grad,
......@@ -254,6 +257,7 @@ PD_REGISTER_KERNEL(relu_grad,
PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/gelu_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
......@@ -91,16 +92,18 @@ template <typename T>
using AbsOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
template <typename T>
using ReluOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
using EluOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
template <typename T>
using Relu6OneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
using ExpOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
template <typename T>
using SwishOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
using GeluTanhOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;
template <typename T>
using GeluErfOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_erf>;
template <typename T>
using HardSwishOneDNNFunctor =
......@@ -111,40 +114,46 @@ using MishOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
using ReluOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using TanhOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
using Relu6OneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
template <typename T>
using SqrtOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
using RoundOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
template <typename T>
using EluOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
using SigmoidOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
template <typename T>
using ExpOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
using SqrtOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
template <typename T>
using RoundOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
using SwishOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using TanhOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
DEFINE_ONEDNN_ACTIVATION_KERNEL(Abs, AbsOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Exp, ExpOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sigmoid, SigmoidOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor)
// round eltwise primitive doesn't support BF16, nor does it support grad
DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6OneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta)
......@@ -159,6 +168,20 @@ void HardSwishKernel(const Context& dev_ctx,
functor(dev_ctx, x, threshold, 0, out);
}
template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
const DenseTensor& x,
bool approximate,
DenseTensor* out) {
if (approximate) {
GeluTanhOneDNNFunctor<T> functor;
functor(dev_ctx, x, 0, 0, out);
} else {
GeluErfOneDNNFunctor<T> functor;
functor(dev_ctx, x, 0, 0, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}
......@@ -170,6 +193,7 @@ PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}
PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册