diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc
index ab7c61227114fe7a0ce2ff2515dd560706058b64..fcc06a709372d7e1b94e4367532df8ff271c17e6 100644
--- a/paddle/fluid/operators/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/activation_mkldnn_op.cc
@@ -52,9 +52,11 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
                                                mkldnn::memory::format::nchw);
 
   // create memory primitives
-  auto src_memory =
+  auto src_memory = std::make_shared<mkldnn::memory>(
       mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(src_data)));
+                     static_cast<void *>(const_cast<float *>(src_data))));
+  // save source memory to device context to be referred in backward path
+  dev_ctx.SetBlob("InputX@eltwise_pd", src_memory);
   auto dst_memory =
       mkldnn::memory({data_md, mkldnn_engine},
                      static_cast<void *>(const_cast<float *>(dst_data)));
@@ -69,7 +71,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
       forward_desc, mkldnn_engine);
   dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
 
-  auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
+  auto eltwise = mkldnn::eltwise_forward(*forward_pd, *src_memory, dst_memory);
 
   // push primitive to stream and wait until it's executed
   std::vector<mkldnn::primitive> pipeline = {eltwise};
@@ -83,8 +85,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   // get buffers
-  const auto *x = ctx.template Input<Tensor>("X");
-  const auto *src = x->template data<T>();
+  const auto *x = ctx.template Input<Tensor>("Out");
 
   auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
   const auto *diff_dst = dout->template data<T>();
@@ -103,9 +104,11 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
                      : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
                                                mkldnn::memory::format::nchw);
 
+  // retrieve source memory from device context
+  const std::shared_ptr<void> src_memory = dev_ctx.GetBlob("InputX@eltwise_pd");
+  auto *p_src_memory = static_cast<mkldnn::memory *>(src_memory.get());
+
   // create memory primitives
-  auto src_memory = mkldnn::memory(
-      {data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
   auto diff_src_memory =
       mkldnn::memory({data_md, mkldnn_engine},
                      static_cast<void *>(const_cast<float *>(diff_src)));
@@ -128,8 +131,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
       backward_desc, mkldnn_engine, *p_forward_pd);
 
-  auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
-                                              diff_dst_memory, diff_src_memory);
+  auto eltwise_bwd = mkldnn::eltwise_backward(
+      eltwise_bwd_prim_desc, *p_src_memory, diff_dst_memory, diff_src_memory);
 
   // push primitive to stream and wait until it's executed
   std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 55482abdf09516077a94ca99140ae7961f0915aa..6f7a965bcf308694d441c39d165ac12ab04201c3 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -458,21 +458,22 @@ namespace ops = paddle::operators;
 
 #define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
   __macro(Sigmoid, sigmoid);                 \
-  __macro(Relu, relu);                       \
   __macro(Exp, exp);                         \
-  __macro(Tanh, tanh);                       \
   __macro(Ceil, ceil);                       \
   __macro(Floor, floor);                     \
-  __macro(Sqrt, sqrt);                       \
   __macro(SoftRelu, soft_relu);              \
   __macro(Relu6, relu6);                     \
   __macro(Reciprocal, reciprocal);           \
   __macro(HardSigmoid, hard_sigmoid);
 
+#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \
+  __macro(Relu, relu);                              \
+  __macro(Tanh, tanh);                              \
+  __macro(Sqrt, sqrt);
+
 #define FOR_EACH_OP_FUNCTOR(__macro) \
   __macro(LogSigmoid, logsigmoid);   \
   __macro(SoftShrink, softshrink);   \
-  __macro(Abs, abs);                 \
   __macro(Cos, cos);                 \
   __macro(Sin, sin);                 \
   __macro(Round, round);             \
@@ -490,18 +491,32 @@ namespace ops = paddle::operators;
   __macro(Swish, swish);                           \
   __macro(ThresholdedRelu, thresholded_relu);
 
+#define FOR_EACH_MKLDNN_OP_FUNCTOR(__macro) __macro(Abs, abs);
+
 #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)              \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,       \
                     ::paddle::operators::OP_NAME##OpMaker,                \
                     ::paddle::operators::OP_NAME##GradMaker);             \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
 
+#define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \
+  REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp,       \
+                    ::paddle::operators::OP_NAME##OpMaker,          \
+                    ::paddle::operators::OP_NAME##GradMaker);       \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
+
 #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
                     ::paddle::operators::OP_NAME##OpMaker,              \
                     ::paddle::framework::DefaultGradOpDescMaker<true>); \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
 
+#define REGISTER_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE)             \
+  REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp,           \
+                    ::paddle::operators::OP_NAME##OpMaker,              \
+                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
+
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
   REGISTER_OP_CPU_KERNEL(                                                 \
       act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
@@ ... @@
                                 ops::grad_functor<double>>);
 
 FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
+FOR_EACH_MKLDNN_OP_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_OP);
 FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
+FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP);
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
diff --git a/paddle/fluid/operators/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn_activation_op.h
index f26a165b5a59f01f864d62bbf798f4cbffa65371..de8daed1706336575daf99c9197be29eaa5473d1 100644
--- a/paddle/fluid/operators/mkldnn_activation_op.h
+++ b/paddle/fluid/operators/mkldnn_activation_op.h
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <string>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -61,9 +63,9 @@ class MKLDNNActivationGradKernel
 };
 
 namespace {  // NOLINT
-framework::OpKernelType GetKernelType(
-    const framework::ExecutionContext& ctx,
-    const framework::OperatorWithKernel& oper) {
+framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
+                                      const framework::OperatorWithKernel& oper,
+                                      const std::string& name) {
   framework::LibraryType library{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_MKLDNN
   if (library == framework::LibraryType::kPlain &&
@@ -73,7 +75,7 @@ framework::OpKernelType GetKernelType(
 #endif
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
       ctx.GetPlace(), layout, library);
 }
 }  // anonymous namespace
@@ -89,7 +91,7 @@ class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
+    return GetKernelType(ctx, *this, "X");
   }
 };
 
@@ -103,7 +105,7 @@ class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
+    return GetKernelType(ctx, *this, "Out");
  }
 };
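
The change above makes the forward kernel save its source memory in the MKL-DNN device context as a named blob (`"InputX@eltwise_pd"`), and the backward kernel retrieves that same handle instead of rebuilding a memory primitive from the `"X"` input. Below is a minimal, standalone sketch of that hand-off pattern only; `BlobStore` and `FakeMemory` are hypothetical stand-ins for `MKLDNNDeviceContext::SetBlob`/`GetBlob` and `mkldnn::memory`, not the real Paddle or MKL-DNN classes.

```cpp
// Sketch of the forward/backward blob hand-off: the forward pass stores a
// type-erased shared_ptr under a string key, the backward pass looks it up
// and casts it back to the concrete memory type.
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Stand-in for mkldnn::memory: just wraps a raw data pointer.
struct FakeMemory {
  explicit FakeMemory(void *data) : data_(data) {}
  void *data_;
};

// Stand-in for the device context's blob map (SetBlob/GetBlob).
class BlobStore {
 public:
  void SetBlob(const std::string &key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }
  std::shared_ptr<void> GetBlob(const std::string &key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

int main() {
  BlobStore dev_ctx;
  float src_data[4] = {1.f, 2.f, 3.f, 4.f};

  // Forward pass: wrap the input buffer and save the handle for backward.
  auto src_memory = std::make_shared<FakeMemory>(static_cast<void *>(src_data));
  dev_ctx.SetBlob("InputX@eltwise_pd", src_memory);

  // Backward pass: retrieve the type-erased handle and cast it back.
  const std::shared_ptr<void> blob = dev_ctx.GetBlob("InputX@eltwise_pd");
  auto *p_src_memory = static_cast<FakeMemory *>(blob.get());
  assert(p_src_memory != nullptr && p_src_memory->data_ == src_data);
  return 0;
}
```

Holding the handle as a `std::shared_ptr<void>` in the blob map keeps the forward-pass memory primitive alive until the backward primitive has consumed it, which is why `src_memory` is wrapped in `std::make_shared` in the forward kernel above.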