diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
index e45964aadc72d45793846a896e05bff12ec68f9a..89face8faaeed8c306ebd482dfb5d4371a92b6a3 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
@@ -25,8 +25,8 @@ namespace operators {
 using framework::DataLayout;
 using framework::Tensor;
 using mkldnn::memory;
-using mkldnn::reorder;
 using mkldnn::primitive;
+using mkldnn::reorder;
 using mkldnn::stream;
 using mkldnn::sum;
 
@@ -34,51 +34,29 @@ template <typename T>
 class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& dev_ctx =
+    const auto& dev_ctx =
         ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
+    const auto* x = ctx.Input<Tensor>("X");
+    const auto* y = ctx.Input<Tensor>("Y");
     auto* z = ctx.Output<Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(
-        x->layout(), DataLayout::kMKLDNN,
-        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
-    PADDLE_ENFORCE_NE(
-        x->format(), MKLDNNMemoryFormat::undef,
-        platform::errors::InvalidArgument("Wrong format set for X tensor"));
-
-    PADDLE_ENFORCE_EQ(
-        y->layout(), DataLayout::kMKLDNN,
-        platform::errors::InvalidArgument("Wrong layout set for Y tensor"));
-    PADDLE_ENFORCE_NE(
-        y->format(), MKLDNNMemoryFormat::undef,
-        platform::errors::InvalidArgument("Wrong format set for Y tensor"));
-
-    auto src_x_tz = framework::vectorize(x->dims());
-    auto src_y_tz = framework::vectorize(y->dims());
-    auto dst_tz = framework::vectorize(z->dims());
-
-    // Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
-    // TODO(jczaja): Binary primitive support broadcasting, so we can support
-    // this in kernel
     platform::BinaryMKLDNNHandler<T> handler(
-        dnnl::algorithm::binary_add, src_x_tz, x->format(), y->format(),
-        dev_ctx, ctx.GetPlace(), ctx.OutputName("Out"));
+        dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, ctx.OutputName("Out"));
 
-    auto src_x_memory = handler.AcquireSrcMemory(x);
-    auto src_y_memory = handler.AcquireSecondSrcMemory(y);
+    const auto src_x_memory = handler.AcquireSrcMemory(x);
+    const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
 
     // For Inplace src and dst are the same memory object
-    auto dst_memory =
+    const auto dst_memory =
         x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z);
 
-    auto binary_prim = handler.AcquireForwardPrimitive();
+    const auto binary_prim = handler.AcquireForwardPrimitive();
 
     mkldnn::stream astream(mkldnn_engine);
-    std::unordered_map<int, dnnl::memory> args = {
+    const std::unordered_map<int, dnnl::memory> args = {
         {DNNL_ARG_SRC_0, *src_x_memory},
         {DNNL_ARG_SRC_1, *src_y_memory},
         {DNNL_ARG_DST, *dst_memory}};
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 2fd7e614cc7b811ca3f49ae8386c29d729eaa697..9a20010532b032d9ac2fc8828023496005832ba0 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -27,6 +27,8 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+using framework::DataLayout;
+using framework::Tensor;
 using user_function = std::function<std::shared_ptr<float>(const float*)>;
 using memory = mkldnn::memory;
 
@@ -108,6 +110,13 @@ class MKLDNNHandlerT {
   }
 
  protected:
+  bool isCached() {
+    const std::string key_pd = key_common_ + "@forward_pd";
+    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
+        dev_ctx_.GetBlob(key_pd));
+    return (fwd_pd_ != nullptr);
+  }
+
   template <typename... Args>
   void AcquireForwardPrimitiveDescriptor(Args&&... args) {
     // Forward PD has to be passed to Grad op that
@@ -355,22 +364,46 @@ class MKLDNNHandler {
 template <typename T>
 class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
  public:
-  BinaryMKLDNNHandler(const dnnl::algorithm algo,
-                      const std::vector<int64_t>& dims,
-                      const MKLDNNMemoryFormat src0_fmt,
-                      const MKLDNNMemoryFormat src1_fmt,
-                      const platform::MKLDNNDeviceContext& dev_ctx,
-                      platform::Place cpu_place, const std::string& uniq_name)
+  BinaryMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                      const mkldnn::engine engine, platform::Place cpu_place,
+                      const Tensor* x, const Tensor* y, Tensor* z,
+                      const std::string uniq_name)
       : platform::MKLDNNHandlerT<T, dnnl::binary>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
-    // TODO(jczaja): Add function checking if data already exists
-    auto src0_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src0_fmt);
-    auto src1_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src1_fmt);
-    auto dst_md =
-        memory::desc(dims, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
-
-    this->AcquireForwardPrimitiveDescriptor(algo, src0_md, src1_md, dst_md);
+            dev_ctx, engine, cpu_place,
+            platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          x->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for X tensor"));
+      PADDLE_ENFORCE_NE(
+          x->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for X tensor"));
+
+      PADDLE_ENFORCE_EQ(
+          y->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for Y tensor"));
+      PADDLE_ENFORCE_NE(
+          y->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for Y tensor"));
+
+      const auto src_x_tz = framework::vectorize(x->dims());
+      const auto src_y_tz = framework::vectorize(y->dims());
+      const auto dst_tz = framework::vectorize(z->dims());
+
+      // TODO(jczaja): Add function checking if data already exists
+      const auto src0_md = dnnl::memory::desc(
+          src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
+      const auto src1_md = dnnl::memory::desc(
+          src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
+      const auto dst_md = memory::desc(
+          dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+
+      // Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
+      // TODO(jczaja): Binary primitive support broadcasting, so we can support
+      // this in kernel
+      this->AcquireForwardPrimitiveDescriptor(dnnl::algorithm::binary_add,
+                                              src0_md, src1_md, dst_md);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
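
What this patch buys: MKLDNNHandlerT gains an isCached() hook that looks the forward primitive descriptor up in the device context's blob map, under the handler key (input dims plus output name, built by platform::CreateKey) suffixed with "@forward_pd". BinaryMKLDNNHandler then runs the layout/format validation and the descriptor construction only on a cache miss, so repeated executions of the same elementwise_add configuration skip both. The code below is a minimal, self-contained C++ sketch of that keyed-cache pattern, not the real implementation: BlobCache, FakePrimitiveDesc, Handler, and the simplified CreateKey are hypothetical stand-ins for MKLDNNDeviceContext's blob storage, oneDNN's dnnl::binary::primitive_desc, the handler class, and platform::CreateKey.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for MKLDNNDeviceContext's blob map: type-erased
// storage keyed by string, shared across op executions.
class BlobCache {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// Hypothetical stand-in for an expensive-to-build forward primitive
// descriptor (dnnl::binary::primitive_desc in the real code).
struct FakePrimitiveDesc {
  std::vector<int64_t> dims;
};

// Simplified analogue of platform::CreateKey(dims, uniq_name).
std::string CreateKey(const std::vector<int64_t>& dims,
                      const std::string& uniq_name) {
  std::string key;
  for (int64_t d : dims) key += std::to_string(d) + "x";
  return key + "@" + uniq_name;
}

class Handler {
 public:
  Handler(BlobCache* cache, const std::vector<int64_t>& dims,
          const std::string& uniq_name)
      : cache_(cache), key_common_(CreateKey(dims, uniq_name)) {
    // Same control flow as the patched BinaryMKLDNNHandler: validate and
    // build only when nothing is cached under this key yet.
    if (!isCached()) {
      // ... input validation (the PADDLE_ENFORCE_* checks) would go here ...
      fwd_pd_ = std::make_shared<FakePrimitiveDesc>(FakePrimitiveDesc{dims});
      cache_->SetBlob(key_common_ + "@forward_pd", fwd_pd_);
      std::cout << "miss: built pd for " << key_common_ << "\n";
    } else {
      std::cout << "hit: reused pd for " << key_common_ << "\n";
    }
  }

 private:
  // Mirrors MKLDNNHandlerT::isCached(): fetch the blob stored under the
  // "@forward_pd" suffix and keep it as this handler's fwd_pd_ if present.
  bool isCached() {
    fwd_pd_ = std::static_pointer_cast<FakePrimitiveDesc>(
        cache_->GetBlob(key_common_ + "@forward_pd"));
    return fwd_pd_ != nullptr;
  }

  BlobCache* cache_;
  std::string key_common_;
  std::shared_ptr<FakePrimitiveDesc> fwd_pd_;
};

int main() {
  BlobCache cache;
  Handler first(&cache, {2, 3, 4, 5}, "Out");   // miss: builds the pd
  Handler second(&cache, {2, 3, 4, 5}, "Out");  // hit: reuses it
  Handler other(&cache, {8, 3, 4, 5}, "Out");   // new shape -> new key -> miss
  return 0;
}

Because the cache key encodes the input dims, a shape change yields a different key and therefore a fresh descriptor; that is what makes it safe to skip the PADDLE_ENFORCE checks on a hit, since any cached entry was validated when it was first built.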