diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
index e957321e9ca2af425342699f23c538913751b711..51bc534bff27c48d7f24c82057008a2367dd073a 100644
--- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
@@ -25,12 +25,12 @@
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 using paddle::platform::MKLDNNMemDesc;
-using mkldnn::memory;  // Note: paddle has also "memory" namespace
-using mkldnn::primitive;
-using mkldnn::prop_kind;
-using mkldnn::softmax_backward;
-using mkldnn::softmax_forward;
-using mkldnn::stream;
+using dnnl::memory;  // Note: paddle has also "memory" namespace
+using dnnl::primitive;
+using dnnl::prop_kind;
+using dnnl::softmax_backward;
+using dnnl::softmax_forward;
+using dnnl::stream;
 using platform::to_void_cast;
 
 template <typename T>
@@ -38,19 +38,30 @@ class SoftmaxMKLDNNHandler
     : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                       mkldnn::softmax_backward> {
  public:
-  SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
-                       const MKLDNNMemoryFormat fmt, const int& axis,
-                       const platform::MKLDNNDeviceContext& dev_ctx,
-                       platform::Place cpu_place, const std::string& uniq_name)
+  SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                       const mkldnn::engine mkldnn_engine,
+                       platform::Place cpu_place, const Tensor* input,
+                       Tensor* output, const int axis,
+                       const std::string uniq_name)
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            dev_ctx, mkldnn_engine, cpu_place,
             // Softmax may be inplace then uniq_name is no longer unique
-            platform::CreateKey(dims, axis, uniq_name)) {
-    auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
-                                            axis);
+            platform::CreateKey(framework::vectorize(input->dims()), axis,
+                                uniq_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          input->dims(), output->dims(),
+          platform::errors::InvalidArgument(
+              "The shape of input and output tensor must be identical."));
+
+      auto softmax_tz = framework::vectorize(input->dims());
+      auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
+                             input->format());
+
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
+                                              axis);
+    }
   }
 
   SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
@@ -76,30 +87,25 @@ template <typename T>
 class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
-    PADDLE_ENFORCE_EQ(
-        input->dims(), output->dims(),
-        "The shape of softmax's input and output must be identical.");
-
-    auto dims = input->dims();  // input and output share the same shape
-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
-    auto softmax_tz = paddle::framework::vectorize(dims);
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
 
-    SoftmaxMKLDNNHandler<T> handler(softmax_tz, input->format(), axis, dev_ctx,
-                                    ctx.GetPlace(), ctx.OutputName("Out"));
+    SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
+                                    input, output, axis, ctx.OutputName("Out"));
 
     auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
-    auto softmax_p = handler.AcquireForwardPrimitive();
     // For Inplace src and and dst are the same memory object
     auto softmax_dst_memory_p = input->IsSharedBufferWith(*output)
                                     ? softmax_src_memory_p
                                     : handler.AcquireDstMemory(output);
+    auto softmax_p = handler.AcquireForwardPrimitive();
+
     mkldnn::stream astream(dev_ctx.GetEngine());
     softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
                                  {DNNL_ARG_DST, *softmax_dst_memory_p}});
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index b17c6b79220677db6dfca7fc94456cefb8e5189d..4248a2b859f63817291089524794804f6dfdcd04 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -114,7 +114,9 @@ class MKLDNNHandlerT {
     const std::string key_pd = key_common_ + "@forward_pd";
     fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
         dev_ctx_.GetBlob(key_pd));
-    return (fwd_pd_ != nullptr);
+
+    const std::string key_p = key_ + "@forward_p";
+    return (dev_ctx_.GetBlob(key_p) != nullptr);
   }
 
   template <typename Arg, typename... Args>
@@ -367,7 +369,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
   BinaryMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
                       const mkldnn::engine engine, platform::Place cpu_place,
                       const Tensor* x, const Tensor* y, Tensor* z,
-                      const std::string uniq_name)
+                      const std::string& uniq_name)
       : platform::MKLDNNHandlerT<T, dnnl::binary>(
             dev_ctx, engine, cpu_place,
             platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
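For context, the behavioural core of this patch is that SoftmaxMKLDNNHandler now derives its cache key from the input dims, the axis and the output name, and only builds the forward primitive descriptor when isCached() reports a miss; on later iterations the primitive is reused from the device-context blob cache. Below is a minimal standalone sketch of that guard pattern, assuming nothing from Paddle: CachedHandler, BlobMap and the lambda-based factory are illustrative stand-ins, with a plain std::unordered_map playing the role of the MKLDNNDeviceContext blob cache.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for the per-device blob cache (GetBlob/SetBlob in Paddle).
using BlobMap = std::unordered_map<std::string, std::shared_ptr<void>>;

// Illustrative handler: the expensive setup runs only on a cache miss,
// mirroring the `if (!this->isCached())` guard added in the diff.
class CachedHandler {
 public:
  CachedHandler(BlobMap* blobs, const std::string& key)
      : blobs_(blobs), key_(key) {}

  // A forward primitive stored under this key means descriptor-building
  // work can be skipped entirely.
  bool isCached() const { return blobs_->count(key_ + "@forward_p") != 0; }

  // Create the primitive once per key, then reuse it from the cache.
  std::shared_ptr<int> AcquireForwardPrimitive(
      const std::function<std::shared_ptr<int>()>& create) {
    const std::string key_p = key_ + "@forward_p";
    auto it = blobs_->find(key_p);
    if (it == blobs_->end()) {
      auto p = create();  // expensive work, done on the first call only
      (*blobs_)[key_p] = p;
      return p;
    }
    return std::static_pointer_cast<int>(it->second);
  }

 private:
  BlobMap* blobs_;
  std::string key_;
};

int main() {
  BlobMap blobs;  // lives as long as the device context would
  for (int i = 0; i < 2; ++i) {
    // Key built from shape, axis and output name, as platform::CreateKey does.
    CachedHandler handler(&blobs, "2x3x4_1_Out");
    if (!handler.isCached()) {
      std::cout << "cache miss: building forward primitive descriptor\n";
    }
    auto prim = handler.AcquireForwardPrimitive(
        [] { return std::make_shared<int>(42); });
    std::cout << "iteration " << i << ", primitive payload " << *prim << "\n";
  }
  return 0;
}

Keeping the miss check in the constructor, as the diff does via isCached(), means shape validation and descriptor creation are paid once per unique key rather than on every Compute() call.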