Unverified commit 718183f1 authored by jakpiase, committed by GitHub

Added exp FP32 FWD/BWD oneDNN kernel and optimized other oneDNN grad kernels (#38624)

* added exp activation and use_dst_for_bwd kernels

* CI RERUN

* minor change
Parent 36a102f8
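Background on the use_dst_for_bwd switch (reasoning not spelled out in the commit message, but standard calculus plus the documented oneDNN algorithm variants): each of the converted activations has a derivative that can be written purely in terms of the forward output, so the backward primitive can consume DNNL_ARG_DST instead of DNNL_ARG_SRC and the grad kernel no longer needs the input X:

    out = exp(x)      =>  dX = dOut * out
    out = sigmoid(x)  =>  dX = dOut * out * (1 - out)
    out = tanh(x)     =>  dX = dOut * (1 - out^2)
    out = sqrt(x)     =>  dX = dOut / (2 * out)
    out = elu(x)      =>  dX = dOut * (out + alpha) for x < 0, dX = dOut otherwise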
@@ -83,9 +83,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  auto *y = ctx.Output<Tensor>("Out");
+  auto *out = ctx.Output<Tensor>("Out");
 
-  bool is_inplaced = x->IsSharedBufferWith(*y);
+  bool is_inplaced = x->IsSharedBufferWith(*out);
 
   platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                                ctx.GetPlace(), x);
@@ -94,9 +94,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
   if (is_inplaced) {
     dst_memory_p = src_memory_p;
-    y->mutable_data<T>(ctx.GetPlace());
+    out->mutable_data<T>(ctx.GetPlace());
   } else {
-    dst_memory_p = handler.AcquireDstMemory(y);
+    dst_memory_p = handler.AcquireDstMemory(out);
   }
 
   auto activation_p = handler.AcquireForwardPrimitive();
@@ -105,8 +105,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
       astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
   astream.wait();
 
-  y->set_layout(DataLayout::kMKLDNN);
-  y->set_format(GetMKLDNNFormat(*dst_memory_p));
+  out->set_layout(DataLayout::kMKLDNN);
+  out->set_format(GetMKLDNNFormat(*dst_memory_p));
 }
 
 template <typename T>
@@ -116,15 +116,15 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
-  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
 
   platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
-                                               ctx.GetPlace(), x, diff_y);
+                                               ctx.GetPlace(), x, dout);
 
   auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
-  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
-  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
   auto activation_backward_p = handler.AcquireBackwardPrimitive();
 
   auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
@@ -134,8 +134,37 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
                                   {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
   astream.wait();
 
-  diff_x->set_layout(DataLayout::kMKLDNN);
-  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+  dx->set_layout(DataLayout::kMKLDNN);
+  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
 }
 
+template <typename T>
+void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
+                          dnnl::algorithm algorithm) {
+  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+  const auto &mkldnn_engine = dev_ctx.GetEngine();
+
+  const auto *out = ctx.Input<Tensor>("Out");
+  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
+                                               ctx.GetPlace(), out, dout);
+
+  auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
+  auto activation_backward_p = handler.AcquireBackwardPrimitive();
+
+  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+  activation_backward_p->execute(astream,
+                                 {{DNNL_ARG_DST, *dst_memory_p},
+                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
+                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
+  astream.wait();
+
+  dx->set_layout(DataLayout::kMKLDNN);
+  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+}
+
 template <typename T, dnnl::algorithm algorithm>
@@ -152,6 +181,13 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T, dnnl::algorithm algorithm>
+struct MKLDNNActivationGradUseOutFunc : public BaseActivationFunctor<T> {
+  void operator()(const framework::ExecutionContext &ctx) const {
+    eltwise_grad_use_out<T>(ctx, algorithm);
+  }
+};
+
 template <typename T>
 struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -217,6 +253,9 @@ using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
 template <typename T>
 using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
 
+template <typename T>
+using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
+
 template <typename T>
 using ReluMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
@@ -234,24 +273,29 @@ using HardSwishMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
 
 template <typename T>
-using SigmoidMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_logistic>;
+using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
 
 template <typename T>
-using TanhMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_tanh>;
+using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
 
 template <typename T>
-using SqrtMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_sqrt>;
+using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
 
 template <typename T>
 using AbsMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
 
 template <typename T>
-using EluMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_elu>;
+using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
+
+template <typename T>
+using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
 
 }  // namespace operators
 }  // namespace paddle
@@ -281,9 +325,10 @@ namespace ops = paddle::operators;
   __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);           \
   __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);              \
   __macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
-  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);                 \
+  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);           \
   __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                    \
-  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradFunctor);
+  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor);              \
+  __macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);
 
 FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
@@ -291,9 +336,9 @@ REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
                                        GeluMKLDNNGradFunctor);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
-                                       SigmoidMKLDNNGradFunctor);
+                                       SigmoidMKLDNNGradUseOutFunctor);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
-                                       SqrtMKLDNNGradFunctor);
+                                       SqrtMKLDNNGradUseOutFunctor);
 
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(
......
@@ -349,6 +349,16 @@ class TestMKLDNNEluCustomAlpha(TestMKLDNNEluDefaultAlpha):
         self.alpha = 2.5
 
 
+class TestMKLDNNExpOp(TestActivation):
+    def setUp(self):
+        self.op_type = "exp"
+        x = np.random.random((5, 5, 4)).astype("float32")
+
+        self.inputs = {'X': x}
+        self.attrs = {'use_mkldnn': True}
+        self.outputs = {'Out': np.exp(x)}
+
+
 # Check if primitives already exist in backward
 class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
     def setUp(self):
......