Unverified · Commit 718183f1 authored by jakpiase, committed by GitHub

Added exp FP32 FWD/BWD oneDNN kernel and optimized other oneDNN grad kernels (#38624)

* added exp activation and use_dst_for_bwd kernels

* CI RERUN

* minor change
Parent 36a102f8
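The `use_dst_for_bwd` kernels mentioned in the commit message rely on the fact that for these activations the derivative can be written in terms of the forward output `Out` rather than the input `X` (exp'(x) = exp(x) = out, sigmoid'(x) = out * (1 - out), tanh'(x) = 1 - out^2, sqrt'(x) = 1 / (2 * out), and similarly for elu), so the backward primitive can consume the already-computed `dst` tensor instead of re-reading and re-evaluating `src`. Below is a minimal standalone sketch of the exp case; it is an illustration only, not part of the patch, and `exp_grad_from_out` is a made-up name.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Backward of y = exp(x) expressed purely in terms of the forward output:
// dx[i] = dout[i] * exp(x[i]) = dout[i] * out[i], because exp'(x) == exp(x).
// This is the identity that dnnl::algorithm::eltwise_exp_use_dst_for_bwd exploits.
std::vector<float> exp_grad_from_out(const std::vector<float> &out,
                                     const std::vector<float> &dout) {
  std::vector<float> dx(out.size());
  for (std::size_t i = 0; i < out.size(); ++i) {
    dx[i] = dout[i] * out[i];  // no second call to std::exp is needed
  }
  return dx;
}

int main() {
  const std::vector<float> x = {0.1f, 0.5f, 1.0f};
  std::vector<float> out(x.size());
  std::vector<float> dout(x.size(), 1.0f);  // gradient of sum(exp(x)) w.r.t. out
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = std::exp(x[i]);

  // Gradient w.r.t. x, computed from "Out" alone.
  for (float g : exp_grad_from_out(out, dout)) std::printf("%f\n", g);
  return 0;
}
```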
@@ -83,9 +83,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  auto *y = ctx.Output<Tensor>("Out");
+  auto *out = ctx.Output<Tensor>("Out");
 
-  bool is_inplaced = x->IsSharedBufferWith(*y);
+  bool is_inplaced = x->IsSharedBufferWith(*out);
 
   platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                                ctx.GetPlace(), x);
@@ -94,9 +94,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
   if (is_inplaced) {
     dst_memory_p = src_memory_p;
-    y->mutable_data<T>(ctx.GetPlace());
+    out->mutable_data<T>(ctx.GetPlace());
   } else {
-    dst_memory_p = handler.AcquireDstMemory(y);
+    dst_memory_p = handler.AcquireDstMemory(out);
   }
   auto activation_p = handler.AcquireForwardPrimitive();
@@ -105,8 +105,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
       astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
   astream.wait();
 
-  y->set_layout(DataLayout::kMKLDNN);
-  y->set_format(GetMKLDNNFormat(*dst_memory_p));
+  out->set_layout(DataLayout::kMKLDNN);
+  out->set_format(GetMKLDNNFormat(*dst_memory_p));
 }
 
 template <typename T>
@@ -116,15 +116,15 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
-  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
 
   platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
-                                               ctx.GetPlace(), x, diff_y);
+                                               ctx.GetPlace(), x, dout);
 
   auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
-  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
-  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
   auto activation_backward_p = handler.AcquireBackwardPrimitive();
 
   auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
@@ -134,8 +134,37 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
                                   {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
   astream.wait();
 
-  diff_x->set_layout(DataLayout::kMKLDNN);
-  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+  dx->set_layout(DataLayout::kMKLDNN);
+  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+}
+
+template <typename T>
+void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
+                          dnnl::algorithm algorithm) {
+  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+  const auto &mkldnn_engine = dev_ctx.GetEngine();
+
+  const auto *out = ctx.Input<Tensor>("Out");
+  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
+                                               ctx.GetPlace(), out, dout);
+
+  auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
+  auto activation_backward_p = handler.AcquireBackwardPrimitive();
+
+  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+  activation_backward_p->execute(astream,
+                                 {{DNNL_ARG_DST, *dst_memory_p},
+                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
+                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
+  astream.wait();
+
+  dx->set_layout(DataLayout::kMKLDNN);
+  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
 }
 
 template <typename T, dnnl::algorithm algorithm>
@@ -152,6 +181,13 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T, dnnl::algorithm algorithm>
+struct MKLDNNActivationGradUseOutFunc : public BaseActivationFunctor<T> {
+  void operator()(const framework::ExecutionContext &ctx) const {
+    eltwise_grad_use_out<T>(ctx, algorithm);
+  }
+};
+
 template <typename T>
 struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -217,6 +253,9 @@ using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
 template <typename T>
 using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
 
+template <typename T>
+using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
+
 template <typename T>
 using ReluMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
@@ -234,24 +273,29 @@ using HardSwishMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
 
 template <typename T>
-using SigmoidMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_logistic>;
+using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
 
 template <typename T>
-using TanhMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_tanh>;
+using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
 
 template <typename T>
-using SqrtMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_sqrt>;
+using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
 
 template <typename T>
 using AbsMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
 
 template <typename T>
-using EluMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_elu>;
+using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
+
+template <typename T>
+using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
 
 }  // namespace operators
 }  // namespace paddle
@@ -281,9 +325,10 @@ namespace ops = paddle::operators;
   __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);           \
   __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);              \
   __macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
-  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);                 \
+  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);           \
   __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                    \
-  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradFunctor);
+  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor);              \
+  __macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);
 
 FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
 
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
@@ -291,9 +336,9 @@ REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
                                        GeluMKLDNNGradFunctor);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
-                                       SigmoidMKLDNNGradFunctor);
+                                       SigmoidMKLDNNGradUseOutFunctor);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
-                                       SqrtMKLDNNGradFunctor);
+                                       SqrtMKLDNNGradUseOutFunctor);
 
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(
...
@@ -349,6 +349,16 @@ class TestMKLDNNEluCustomAlpha(TestMKLDNNEluDefaultAlpha):
         self.alpha = 2.5
 
 
+class TestMKLDNNExpOp(TestActivation):
+    def setUp(self):
+        self.op_type = "exp"
+        x = np.random.random((5, 5, 4)).astype("float32")
+
+        self.inputs = {'X': x}
+        self.attrs = {'use_mkldnn': True}
+        self.outputs = {'Out': np.exp(x)}
+
+
 # Check if primitives already exist in backward
 class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
     def setUp(self):
...