diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index 1138d5113929329462a7ea6ccd01f1b7bc375322..4a55945936ed57e4453ef39b55e6a3d8db4784a8 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -15,15 +15,6 @@ limitations under the License. */ #include "paddle/fluid/operators/softmax_op.h" #include "paddle/fluid/platform/mkldnn_reuse.h" -namespace paddle { -namespace framework { -class Tensor; -} // namespace framework -namespace platform { -class MKLDNNDeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace operators { @@ -74,22 +65,34 @@ class SoftmaxMKLDNNHandler } } - SoftmaxMKLDNNHandler(const std::vector& dims, - const MKLDNNMemoryFormat fmt, - const MKLDNNMemoryFormat diff_fmt, const int& axis, - const platform::MKLDNNDeviceContext& dev_ctx, - platform::Place cpu_place, const std::string& uniq_name) + SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, + platform::Place cpu_place, const Tensor* out, + const Tensor* out_grad, Tensor* in_x_grad, + const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dev_ctx, dims, uniq_name)) { - auto data_softmax_md = - mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); - auto diff_softmax_md = - mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); - - this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, - axis); + platform::CreateKey(dev_ctx, framework::vectorize(out->dims()), + unique_name)) { + if (!this->isBwdCached()) { + PADDLE_ENFORCE_EQ( + out_grad->dims(), in_x_grad->dims(), + platform::errors::InvalidArgument("The shape of softmax_grad's input " + "and output must be identical.")); + + auto dims = out_grad->dims(); // input and output share the same shape + const int axis = CanonicalAxis(ctx.Attr("axis"), dims.size()); + auto softmax_tz = framework::vectorize(dims); + + auto data_softmax_md = MKLDNNMemDesc( + softmax_tz, platform::MKLDNNGetDataType(), out->format()); + auto diff_softmax_md = MKLDNNMemDesc( + softmax_tz, platform::MKLDNNGetDataType(), out_grad->format()); + + this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, + axis); + } } }; @@ -145,27 +148,15 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel { "Operator DNNL SoftmaxGrad must use CPUPlace")); auto& dev_ctx = ctx.template device_context(); const Tensor* output = ctx.Input("Out"); - auto* dout = ctx.template Input(framework::GradVarName("Out")); - auto* dx = - ctx.template Output(framework::GradVarName("X")); - - PADDLE_ENFORCE_EQ( - dout->dims(), dx->dims(), - platform::errors::InvalidArgument( - "The shape of softmax_grad's input and output must be identical.")); - - auto dims = dout->dims(); // input and output share the same shape - const int axis = CanonicalAxis(ctx.Attr("axis"), dims.size()); - - auto softmax_tz = paddle::framework::vectorize(dims); + auto* out_grad = ctx.template Input(framework::GradVarName("Out")); + auto* in_x_grad = ctx.template Output(framework::GradVarName("X")); - SoftmaxMKLDNNHandler handler(softmax_tz, output->format(), - dout->format(), axis, dev_ctx, - ctx.GetPlace(), ctx.InputName("Out")); + SoftmaxMKLDNNHandler handler(ctx, dev_ctx, ctx.GetPlace(), output, + out_grad, in_x_grad, ctx.InputName("Out")); auto dst_memory_p = handler.AcquireDstMemory(output); - auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); + auto diff_dst_memory_p = handler.AcquireDiffDstMemory(out_grad); + auto diff_src_memory_p = handler.AcquireDiffSrcMemory(in_x_grad); auto softmax_bwd_p = handler.AcquireBackwardPrimitive(); @@ -176,8 +167,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel { {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}}); astream.wait(); - dx->set_layout(framework::DataLayout::kMKLDNN); - dx->set_format(dout->format()); + in_x_grad->set_layout(framework::DataLayout::kMKLDNN); + in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p)); } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py index 9e2229cece75c2074f288be29440f3027da64e5e..13c1883af6184f2971d43c6bbd88496912c5ec67 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py @@ -129,4 +129,6 @@ class TestSoftmaxMKLDNNPrimitivesAlreadyExist(unittest.TestCase): if __name__ == '__main__': + from paddle import enable_static + enable_static() unittest.main()