Unverified commit cee2b12d authored by Sławomir Siwek, committed by GitHub

[PHI] relu6_grad kernel (#46501)

* Relu6

* remove fluid handler

* add individual kernel signature

* coding style

* replace bounded_relu with clip (equivalence sketched below)

* whitespace

* code style
Parent c7da8602
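The functional core of this change is replacing oneDNN's eltwise_bounded_relu with eltwise_clip_v2. A minimal sketch of why the two are interchangeable for relu6, in plain C++; the helper names here are illustrative and not part of the commit:

#include <algorithm>

// bounded_relu(x; alpha)  clamps x to [0, alpha]
static float bounded_relu(float x, float alpha) {
  return std::min(std::max(x, 0.0f), alpha);
}

// clip_v2(x; alpha, beta) clamps x to [alpha, beta]
static float clip_v2(float x, float alpha, float beta) {
  return std::min(std::max(x, alpha), beta);
}

// relu6(x) = min(max(x, 0), threshold), so for any threshold
//   bounded_relu(x, threshold) == clip_v2(x, /*alpha=*/0.0f, /*beta=*/threshold),
// which is why the kernels in this diff pass (0, threshold) to the clip-based functors.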
@@ -42,51 +42,6 @@ class MKLDNNActivationKernel
}
};
template <typename Functor>
class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
Functor functor;
functor(ctx);
}
};
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, dout);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_mem_desc(diff_src_memory_p->get_desc());
}
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
eltwise_grad<T>(ctx, algorithm);
}
};
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
@@ -94,10 +49,6 @@ struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
using Relu6MKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
} // namespace operators
} // namespace paddle
@@ -111,14 +62,4 @@ namespace ops = paddle::operators;
ops::MKLDNNActivationKernel<ops::functor<float>>, \
ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>);
#define REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(act_type, grad_functor) \
REGISTER_OP_KERNEL( \
act_type##_grad, \
MKLDNN, \
::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>, \
ops::MKLDNNActivationGradKernel< \
ops::grad_functor<paddle::platform::bfloat16>>);
REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor);
REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor);
@@ -294,103 +294,6 @@ class MatMulV2MKLDNNHandler
}
};
template <typename T>
class ActivationMKLDNNHandler
: public MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward> {
public:
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx,
const dnnl::engine engine,
Place cpu_place,
const framework::Tensor* x)
: platform::MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine,
cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
if (ctx.Type() == "scale") {
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
alpha = (scale_tensor == nullptr)
? ctx.Attr<float>("scale")
: static_cast<float>(*(scale_tensor->data<T>()));
beta = ctx.Attr<float>("bias");
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if (!bias_after_scale) {
beta *= alpha;
}
} else if (ctx.Type() == "clip") {
alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
: ctx.Attr<float>("min");
beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
: ctx.Attr<float>("max");
} else {
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == dnnl::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
}
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm,
x->mem_desc(),
alpha,
beta);
}
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx,
const dnnl::engine engine,
Place cpu_place,
const framework::Tensor* x,
const Tensor* dout)
: platform::MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine,
cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == dnnl::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
if (ctx.Type() == "clip_grad") {
alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
: ctx.Attr<float>("min");
beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
: ctx.Attr<float>("max");
}
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm,
x->mem_desc(),
alpha,
beta);
this->AcquireBackwardPrimitiveDescriptor(
algorithm, dout->mem_desc(), x->mem_desc(), alpha, beta);
}
std::shared_ptr<dnnl::memory> AcquireBackwardSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(),
to_void_cast<T>(input_data));
}
};
static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
@@ -166,6 +166,11 @@ template <typename T>
using ReluOneDNNGradFunctor =
OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using Relu6OneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_clip_v2_use_dst_for_bwd>;
template <typename T>
using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
T,
@@ -241,6 +246,16 @@ void HardSwishGradKernel(const Context& dev_ctx,
functor(dev_ctx, x, dout, threshold, 0, dx);
}
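// Relu6 gradient: the _use_dst_for_bwd flavor of clip_v2 lets oneDNN derive the
// backward pass from the forward output, so this kernel needs only `out` and
// `dout`, not the original input `x`.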
template <typename T, typename Context>
void Relu6GradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
float threshold,
DenseTensor* dx) {
Relu6OneDNNGradUseOutFunctor<T> functor;
functor(dev_ctx, out, dout, 0, threshold, dx);
}
} // namespace phi
PD_REGISTER_KERNEL(relu_grad,
@@ -261,6 +276,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
@@ -119,7 +119,7 @@ using ReluOneDNNFunctor =
template <typename T>
using Relu6OneDNNFunctor =
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
OneDNNActivationFunc<T, dnnl::algorithm::eltwise_clip_v2>;
template <typename T>
using RoundOneDNNFunctor =
@@ -154,7 +154,6 @@ DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6OneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta)
template <typename T, typename Context>
@@ -182,6 +181,15 @@ void GeluKernel(const Context& dev_ctx,
}
}
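// Forward relu6 expressed through clip_v2: alpha = 0 is the lower bound and
// beta = threshold the upper bound, matching Relu6OneDNNFunctor above.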
template <typename T, typename Context>
void Relu6Kernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
DenseTensor* out) {
Relu6OneDNNFunctor<T> functor;
functor(dev_ctx, x, 0, threshold, out);
}
} // namespace phi
PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}