Unverified commit 567e2fc8, authored by Sławomir Siwek, committed by GitHub

[PHI] Migrate gelu kernels (#45596)

* gaussian random

* mkldnn to onednn renaming

* fix merge conflicts

* remove fluid code

* onednn renaming

* gelu fwd

* sort activations

* gelu gradient

* remove unused macros

* merge conflicts

* fix merge conflicts

* remove extra constraint from gelu op
Parent: 5679cdff
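In short, this change deletes the fluid oneDNN gelu implementation (GeluMKLDNNFunctor / GeluMKLDNNGradFunctor and their REGISTER_*_ACTIVATION_MKLDNN_KERNEL lines) and re-adds gelu as PHI kernels registered under the OneDNN backend. A condensed sketch of the new forward path, assembled from the hunks below (the functor aliases, macros, and types are the ones the PHI oneDNN activation files already use; nothing here goes beyond what the diff itself shows):

// Condensed from the PHI oneDNN activation kernel changes in this commit.
template <typename T>
using GeluTanhOneDNNFunctor =
    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;

template <typename T>
using GeluErfOneDNNFunctor =
    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_erf>;

template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
                const DenseTensor& x,
                bool approximate,
                DenseTensor* out) {
  // "approximate" picks the tanh approximation or the exact erf-based gelu,
  // exactly as the removed fluid functor did.
  if (approximate) {
    GeluTanhOneDNNFunctor<T> functor;
    functor(dev_ctx, x, 0, 0, out);
  } else {
    GeluErfOneDNNFunctor<T> functor;
    functor(dev_ctx, x, 0, 0, out);
  }
}

PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)

The backward pass mirrors this with GeluTanhOneDNNGradFunctor / GeluErfOneDNNGradFunctor and PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel).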
@@ -510,7 +510,7 @@ function(op_library TARGET)
   if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
     # Append first implemented MKLDNN activation operator
     if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
-      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(gelu, MKLDNN);\n")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
     elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
       file(APPEND ${pybind_file}
            "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
...
@@ -38,7 +38,7 @@ USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
 USE_OP_ITSELF(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
 USE_OP_ITSELF(gelu);
-USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
+PD_DECLARE_KERNEL(gelu, OneDNN, ALL_LAYOUT);
 PD_DECLARE_ARG_MAPPING_FN(gelu);
 
 namespace paddle {
...
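In other words, a translation unit that previously pulled in the fluid gelu MKLDNN kernel now declares the migrated PHI kernel instead. The pattern from this test file, with explanatory comments added here:

USE_OP_ITSELF(gelu);                          // the operator definition is still needed
PD_DECLARE_KERNEL(gelu, OneDNN, ALL_LAYOUT);  // replaces USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);              // fluid-to-PHI argument mapping stays declared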
@@ -77,8 +77,7 @@ class GeluGradOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    auto it = this->Attrs().find("use_mkldnn");
-    if (it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
       return framework::OpKernelType(data_type,
                                      ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
...
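For readability, the surviving branch of GeluGradOp's kernel-type selection after this hunk (only the lines visible above; the method signature and the code past the truncation are unchanged and omitted):

  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  // The manual Attrs().find("use_mkldnn") lookup is gone; CanMKLDNNBeUsed()
  // is now the single gate before choosing the oneDNN kernel type.
  if (this->CanMKLDNNBeUsed(ctx, data_type)) {
    return framework::OpKernelType(data_type,
                                   ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   /* ...truncated in the diff... */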
@@ -52,42 +52,6 @@ class MKLDNNActivationGradKernel
   }
 };
 
-template <typename T>
-void eltwise_forward(const framework::ExecutionContext &ctx,
-                     dnnl::algorithm algorithm) {
-  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
-                    true,
-                    paddle::platform::errors::PreconditionNotMet(
-                        "Operator DNNL eletwise_forward must use CPUPlace"));
-  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-  const auto &mkldnn_engine = dev_ctx.GetEngine();
-
-  const auto *x = ctx.Input<Tensor>("X");
-  auto *out = ctx.Output<Tensor>("Out");
-
-  bool is_inplaced = x->IsSharedBufferWith(*out);
-
-  platform::ActivationMKLDNNHandler<T> handler(
-      algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x);
-
-  auto src_memory_p = handler.AcquireSrcMemory(x);
-  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
-  if (is_inplaced) {
-    dst_memory_p = src_memory_p;
-    out->mutable_data<T>(ctx.GetPlace());
-  } else {
-    dst_memory_p = handler.AcquireDstMemory(out);
-  }
-  auto activation_p = handler.AcquireForwardPrimitive();
-
-  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
-  activation_p->execute(
-      astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
-  astream.wait();
-
-  out->set_mem_desc(dst_memory_p->get_desc());
-}
-
 template <typename T>
 void eltwise_grad(const framework::ExecutionContext &ctx,
                   dnnl::algorithm algorithm) {
@@ -116,34 +80,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   dx->set_mem_desc(diff_src_memory_p->get_desc());
 }
 
-template <typename T>
-void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
-                          dnnl::algorithm algorithm) {
-  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-  const auto &mkldnn_engine = dev_ctx.GetEngine();
-
-  const auto *out = ctx.Input<Tensor>("Out");
-  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-
-  platform::ActivationMKLDNNHandler<T> handler(
-      algorithm, ctx, mkldnn_engine, ctx.GetPlace(), out, dout);
-
-  auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
-  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
-  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
-  auto activation_backward_p = handler.AcquireBackwardPrimitive();
-
-  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
-  activation_backward_p->execute(astream,
-                                 {{DNNL_ARG_DST, *dst_memory_p},
-                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
-                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
-  astream.wait();
-
-  dx->set_mem_desc(diff_src_memory_p->get_desc());
-}
-
 template <typename T, dnnl::algorithm algorithm>
 struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -151,30 +87,6 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
   }
 };
 
-template <typename T>
-struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
-  void operator()(const framework::ExecutionContext &ctx) const {
-    const bool approximate = ctx.Attr<bool>("approximate");
-    if (approximate) {
-      eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
-    } else {
-      eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
-    }
-  }
-};
-
-template <typename T>
-struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
-  void operator()(const framework::ExecutionContext &ctx) const {
-    const bool approximate = ctx.Attr<bool>("approximate");
-    if (approximate) {
-      eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
-    } else {
-      eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
-    }
-  }
-};
-
 template <typename T>
 struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -209,6 +121,4 @@ namespace ops = paddle::operators;
                              ops::grad_functor<paddle::platform::bfloat16>>);
 
 REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor);
-REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNFunctor);
-REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNGradFunctor);
 REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor);
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/activation_grad_kernel.h"
+#include "paddle/phi/kernels/gelu_grad_kernel.h"
 
 #include "paddle/phi/backends/onednn/onednn_context.h"
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
@@ -23,16 +24,6 @@
 namespace phi {
 
-#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
-  template <typename T, typename Context>                              \
-  void name##GradKernel(const Context& dev_ctx,                        \
-                        const DenseTensor& x,                          \
-                        const DenseTensor& dout,                       \
-                        DenseTensor* dx) {                             \
-    functor_class<T> functor;                                          \
-    functor(dev_ctx, x, dout, 0, 0, dx);                               \
-  }
-
 #define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
     name, functor_class, attr)                             \
   template <typename T, typename Context>                  \
@@ -55,18 +46,6 @@ namespace phi {
     functor(dev_ctx, out, dout, 0, 0, dx); \
   }
 
-#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
-    name, functor_class, attr)                               \
-  template <typename T, typename Context>                    \
-  void name##GradKernel(const Context& dev_ctx,              \
-                        const DenseTensor& out,              \
-                        const DenseTensor& dout,             \
-                        float attr,                          \
-                        DenseTensor* dx) {                   \
-    functor_class<T> functor;                                \
-    functor(dev_ctx, out, dout, attr, 0, dx);                \
-  }
-
 template <typename T>
 void eltwise_grad(const OneDNNContext& dev_ctx,
                   const DenseTensor& x,
@@ -158,12 +137,14 @@ using AbsOneDNNGradFunctor =
     OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
 
 template <typename T>
-using ReluOneDNNGradFunctor =
-    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
+using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
+    T,
+    dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
 
 template <typename T>
-using SwishOneDNNGradFunctor =
-    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
+using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
+    T,
+    dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
 
 template <typename T>
 using HardSwishOneDNNGradFunctor =
@@ -174,14 +155,21 @@ using MishOneDNNGradFunctor =
     OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;
 
 template <typename T>
-using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
-    T,
-    dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
+using GeluTanhOneDNNGradFunctor =
+    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;
 
 template <typename T>
-using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
+using GeluErfOneDNNGradFunctor =
+    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_erf>;
+
+template <typename T>
+using ReluOneDNNGradFunctor =
+    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
+
+template <typename T>
+using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
     T,
-    dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
+    dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
 
 template <typename T>
 using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
@@ -189,22 +177,21 @@ using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
     dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
 
 template <typename T>
-using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
-    T,
-    dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
+using SwishOneDNNGradFunctor =
+    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
 
 template <typename T>
-using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
+using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
     T,
-    dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
+    dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
 
-DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor);
-DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor);
-DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
-                                            SigmoidOneDNNGradUseOutFunctor);
-DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor);
 DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Abs, AbsOneDNNGradFunctor);
+DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor);
 DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluOneDNNGradFunctor);
+DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
+                                            SigmoidOneDNNGradUseOutFunctor);
+DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor);
+DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor);
 
 DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
                                                   ReluOneDNNGradFunctor,
@@ -215,17 +202,6 @@ DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
 DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
                                                   SwishOneDNNGradFunctor,
                                                   beta);
 
-template <typename T, typename Context>
-void HardSwishGradKernel(const Context& dev_ctx,
-                         const DenseTensor& x,
-                         const DenseTensor& dout,
-                         float threshold,
-                         float scale,
-                         float offset,
-                         DenseTensor* dx) {
-  HardSwishOneDNNGradFunctor<T> functor;
-  functor(dev_ctx, x, dout, threshold, 0, dx);
-}
-
 template <typename T, typename Context>
 void EluGradKernel(const Context& dev_ctx,
@@ -238,6 +214,33 @@ void EluGradKernel(const Context& dev_ctx,
   functor(dev_ctx, out, dout, alpha, 0, dx);
 }
 
+template <typename T, typename Context>
+void GeluGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    bool approximate,
+                    DenseTensor* x_grad) {
+  if (approximate) {
+    GeluTanhOneDNNGradFunctor<T> functor;
+    functor(dev_ctx, x, out_grad, 0, 0, x_grad);
+  } else {
+    GeluErfOneDNNGradFunctor<T> functor;
+    functor(dev_ctx, x, out_grad, 0, 0, x_grad);
+  }
+}
+
+template <typename T, typename Context>
+void HardSwishGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         float threshold,
+                         float scale,
+                         float offset,
+                         DenseTensor* dx) {
+  HardSwishOneDNNGradFunctor<T> functor;
+  functor(dev_ctx, x, dout, threshold, 0, dx);
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(relu_grad,
@@ -254,6 +257,7 @@ PD_REGISTER_KERNEL(relu_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
...
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/activation_kernel.h"
+#include "paddle/phi/kernels/gelu_grad_kernel.h"
 
 #include "paddle/phi/backends/onednn/onednn_context.h"
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
@@ -91,16 +92,18 @@ template <typename T>
 using AbsOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
 
 template <typename T>
-using ReluOneDNNFunctor =
-    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
+using EluOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
 
 template <typename T>
-using Relu6OneDNNFunctor =
-    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
+using ExpOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
 
 template <typename T>
-using SwishOneDNNFunctor =
-    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
+using GeluTanhOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;
+
+template <typename T>
+using GeluErfOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_erf>;
 
 template <typename T>
 using HardSwishOneDNNFunctor =
@@ -111,40 +114,46 @@ using MishOneDNNFunctor =
     OneDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
 
 template <typename T>
-using SigmoidOneDNNFunctor =
-    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
+using ReluOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
 
 template <typename T>
-using TanhOneDNNFunctor =
-    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
+using Relu6OneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
 
 template <typename T>
-using SqrtOneDNNFunctor =
-    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
+using RoundOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
 
 template <typename T>
-using EluOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
+using SigmoidOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
 
 template <typename T>
-using ExpOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
+using SqrtOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
 
 template <typename T>
-using RoundOneDNNFunctor =
-    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
+using SwishOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
+
+template <typename T>
+using TanhOneDNNFunctor =
+    OneDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
 
 DEFINE_ONEDNN_ACTIVATION_KERNEL(Abs, AbsOneDNNFunctor)
-DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor)
-DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor)
 DEFINE_ONEDNN_ACTIVATION_KERNEL(Exp, ExpOneDNNFunctor)
-DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor)
+DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor)
 DEFINE_ONEDNN_ACTIVATION_KERNEL(Sigmoid, SigmoidOneDNNFunctor)
+DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor)
+DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor)
 
 // round eltwise primitive doesn't support BF16, nor does it support grad
 DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor)
 
+DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
 DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha)
 DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold)
-DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
 DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6OneDNNFunctor, threshold)
 DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta)
@@ -159,6 +168,20 @@ void HardSwishKernel(const Context& dev_ctx,
   functor(dev_ctx, x, threshold, 0, out);
 }
 
+template <typename T, typename Context>
+void GeluKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                bool approximate,
+                DenseTensor* out) {
+  if (approximate) {
+    GeluTanhOneDNNFunctor<T> functor;
+    functor(dev_ctx, x, 0, 0, out);
+  } else {
+    GeluErfOneDNNFunctor<T> functor;
+    functor(dev_ctx, x, 0, 0, out);
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}
@@ -170,6 +193,7 @@ PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}
 PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel)
 PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
+PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
...