未验证 提交 2bcbf8b0 编写于 作者: S Sławomir Siwek 提交者: GitHub

[cherry-pick] [PHI] relu6_grad kernel (#46501) (#46862)

* [PHI] Migrate gelu kernels (#45596)

* gaussian random

* mkldnn to onednn renaming

* fix merge conflicts

* remove fluid code

* onednn renaming

* gelu fwd

* sort activations

* gelu gradient

* remove unused macros

* merge conflicts

* fix merge conflicts

* remove extra contraint from gelu op

* [PHI] relu6_grad kernel (#46501)

* Relu6

* remove fluid handler

* add individual kernel signature

* coding style

* replace bounded_relu with clip

* whitespace

* code style
上级 7b3837e6
......@@ -510,7 +510,7 @@ function(op_library TARGET)
if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(gelu, MKLDNN);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
......
......@@ -38,7 +38,7 @@ USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(gelu);
USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
PD_DECLARE_KERNEL(gelu, OneDNN, ALL_LAYOUT);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
......
......@@ -80,11 +80,11 @@ class GeluGradOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
......
......@@ -42,139 +42,6 @@ class MKLDNNActivationKernel
}
};
template <typename Functor>
class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
Functor functor;
functor(ctx);
}
};
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL eletwise_forward must use CPUPlace"));
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
auto *out = ctx.Output<Tensor>("Out");
bool is_inplaced = x->IsSharedBufferWith(*out);
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x);
auto src_memory_p = handler.AcquireSrcMemory(x);
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
if (is_inplaced) {
dst_memory_p = src_memory_p;
out->mutable_data<T>(ctx.GetPlace());
} else {
dst_memory_p = handler.AcquireDstMemory(out);
}
auto activation_p = handler.AcquireForwardPrimitive();
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_p->execute(
astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, dout);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_mem_desc(diff_src_memory_p->get_desc());
}
template <typename T>
void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *out = ctx.Input<Tensor>("Out");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), out, dout);
auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_DST, *dst_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_mem_desc(diff_src_memory_p->get_desc());
}
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
eltwise_grad<T>(ctx, algorithm);
}
};
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
const bool approximate = ctx.Attr<bool>("approximate");
if (approximate) {
eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
} else {
eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
}
}
};
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
const bool approximate = ctx.Attr<bool>("approximate");
if (approximate) {
eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
} else {
eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
}
}
};
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
......@@ -182,10 +49,6 @@ struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
using Relu6MKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
} // namespace operators
} // namespace paddle
......@@ -199,16 +62,4 @@ namespace ops = paddle::operators;
ops::MKLDNNActivationKernel<ops::functor<float>>, \
ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>);
#define REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(act_type, grad_functor) \
REGISTER_OP_KERNEL( \
act_type##_grad, \
MKLDNN, \
::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>, \
ops::MKLDNNActivationGradKernel< \
ops::grad_functor<paddle::platform::bfloat16>>);
REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor);
REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNFunctor);
REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNGradFunctor);
REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor);
......@@ -293,103 +293,6 @@ class MatMulV2MKLDNNHandler
}
};
template <typename T>
class ActivationMKLDNNHandler
: public MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward> {
public:
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx,
const dnnl::engine engine,
Place cpu_place,
const framework::Tensor* x)
: platform::MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine,
cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
if (ctx.Type() == "scale") {
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
alpha = (scale_tensor == nullptr)
? ctx.Attr<float>("scale")
: static_cast<float>(*(scale_tensor->data<T>()));
beta = ctx.Attr<float>("bias");
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if (!bias_after_scale) {
beta *= alpha;
}
} else if (ctx.Type() == "clip") {
alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
: ctx.Attr<float>("min");
beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
: ctx.Attr<float>("max");
} else {
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == dnnl::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
}
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm,
x->mem_desc(),
alpha,
beta);
}
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx,
const dnnl::engine engine,
Place cpu_place,
const framework::Tensor* x,
const Tensor* dout)
: platform::MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine,
cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == dnnl::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
if (ctx.Type() == "clip_grad") {
alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
: ctx.Attr<float>("min");
beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
: ctx.Attr<float>("max");
}
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm,
x->mem_desc(),
alpha,
beta);
this->AcquireBackwardPrimitiveDescriptor(
algorithm, dout->mem_desc(), x->mem_desc(), alpha, beta);
}
std::shared_ptr<dnnl::memory> AcquireBackwardSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(),
to_void_cast<T>(input_data));
}
};
static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/kernels/gelu_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
......@@ -23,16 +24,6 @@
namespace phi {
#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, x, dout, 0, 0, dx); \
}
#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
name, functor_class, attr) \
template <typename T, typename Context> \
......@@ -55,18 +46,6 @@ namespace phi {
functor(dev_ctx, out, dout, 0, 0, dx); \
}
#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
name, functor_class, attr) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
float attr, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, out, dout, attr, 0, dx); \
}
template <typename T>
void eltwise_grad(const OneDNNContext& dev_ctx,
const DenseTensor& x,
......@@ -158,12 +137,14 @@ using AbsOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
template <typename T>
using ReluOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
template <typename T>
using SwishOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
template <typename T>
using HardSwishOneDNNGradFunctor =
......@@ -174,14 +155,26 @@ using MishOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
using GeluTanhOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;
template <typename T>
using GeluErfOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_erf>;
template <typename T>
using ReluOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using Relu6OneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
dnnl::algorithm::eltwise_clip_v2_use_dst_for_bwd>;
template <typename T>
using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
template <typename T>
using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
......@@ -189,22 +182,21 @@ using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
template <typename T>
using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
using SwishOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
SigmoidOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Abs, AbsOneDNNGradFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluOneDNNGradFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
SigmoidOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
ReluOneDNNGradFunctor,
......@@ -215,6 +207,33 @@ DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
SwishOneDNNGradFunctor,
beta);
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
float alpha,
DenseTensor* dx) {
EluOneDNNGradUseOutFunctor<T> functor;
functor(dev_ctx, out, dout, alpha, 0, dx);
}
template <typename T, typename Context>
void GeluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
bool approximate,
DenseTensor* x_grad) {
if (approximate) {
GeluTanhOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, out_grad, 0, 0, x_grad);
} else {
GeluErfOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, out_grad, 0, 0, x_grad);
}
}
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -228,14 +247,13 @@ void HardSwishGradKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
void Relu6GradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
float alpha,
float threshold,
DenseTensor* dx) {
EluOneDNNGradUseOutFunctor<T> functor;
functor(dev_ctx, out, dout, alpha, 0, dx);
Relu6OneDNNGradUseOutFunctor<T> functor;
functor(dev_ctx, out, dout, 0, threshold, dx);
}
} // namespace phi
......@@ -254,9 +272,11 @@ PD_REGISTER_KERNEL(relu_grad,
PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/gelu_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
......@@ -91,16 +92,18 @@ template <typename T>
using AbsOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
template <typename T>
using ReluOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
using EluOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
template <typename T>
using Relu6OneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
using ExpOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
template <typename T>
using SwishOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
using GeluTanhOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;
template <typename T>
using GeluErfOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_gelu_erf>;
template <typename T>
using HardSwishOneDNNFunctor =
......@@ -111,41 +114,46 @@ using MishOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
using ReluOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using TanhOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
using Relu6OneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_clip_v2>;
template <typename T>
using SqrtOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
using RoundOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
template <typename T>
using EluOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
using SigmoidOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
template <typename T>
using ExpOneDNNFunctor = OneDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
using SqrtOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
template <typename T>
using RoundOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
using SwishOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using TanhOneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
DEFINE_ONEDNN_ACTIVATION_KERNEL(Abs, AbsOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Exp, ExpOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sigmoid, SigmoidOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor)
// round eltwise primitive doesn't support BF16, nor does it support grad
DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6OneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta)
template <typename T, typename Context>
......@@ -159,6 +167,29 @@ void HardSwishKernel(const Context& dev_ctx,
functor(dev_ctx, x, threshold, 0, out);
}
template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
const DenseTensor& x,
bool approximate,
DenseTensor* out) {
if (approximate) {
GeluTanhOneDNNFunctor<T> functor;
functor(dev_ctx, x, 0, 0, out);
} else {
GeluErfOneDNNFunctor<T> functor;
functor(dev_ctx, x, 0, 0, out);
}
}
template <typename T, typename Context>
void Relu6Kernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
DenseTensor* out) {
Relu6OneDNNFunctor<T> functor;
functor(dev_ctx, x, 0, threshold, out);
}
} // namespace phi
PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}
......@@ -170,6 +201,7 @@ PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}
PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册