diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
index cd53a07aca0f63c1876c632653a16d1e4dddc632..2a73458bfeb0e2808cc1a910318679e6a6cde231 100644
--- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <iostream>
+#include <numeric>
 #include "mkldnn.hpp"
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
@@ -38,37 +39,37 @@ class SoftmaxMKLDNNHandler
                                       mkldnn::softmax_backward> {
  public:
   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
-                       const MKLDNNMemoryFormat fmt,
+                       const MKLDNNMemoryFormat fmt, const int& axis,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
+            platform::CreateKey(dims, axis, uniq_name)) {
     auto md =
         mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);

     this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
-                                            1 /*dim: C*/);
+                                            axis);
   }

   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
                        const MKLDNNMemoryFormat fmt,
-                       const MKLDNNMemoryFormat diff_fmt,
+                       const MKLDNNMemoryFormat diff_fmt, const int& axis,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
+            platform::CreateKey(dims, axis, uniq_name)) {
     auto data_softmax_md =
         mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
     auto diff_softmax_md =
         mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);

     this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
-                                            data_softmax_md, 1 /*dim: C*/);
+                                            data_softmax_md, axis);
     this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
-                                             1 /* dim: C*/);
+                                             axis);
   }
 };
@@ -85,18 +86,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
                    input->dims(), output->dims(),
                    "The shape of softmax's input and output must be identical.");

-    // flatten input and output to 2-D matrixs
     auto dims = input->dims();  // input and output share the same shape
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());

-    auto src_tz = paddle::framework::vectorize<int>(flattened_dims);
-    auto dst_tz = src_tz;
-    // Same memory descriptor to be used for input and output
-    memory::dims softmax_tz = {src_tz[0], src_tz[1]};
+    auto softmax_tz = paddle::framework::vectorize<int>(dims);

-    SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, input->format(), axis, dev_ctx,
                                     ctx.GetPlace(), ctx.op().Output("Out"));
-    // Currently only NC data format is supported
+
     auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
     auto softmax_dst_memory_p = handler.AcquireDstMemory(output);
     auto softmax_p = handler.AcquireForwardPrimitive(*softmax_src_memory_p,
@@ -105,14 +102,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     std::vector<primitive> pipeline{*softmax_p};
     stream(stream::kind::eager).submit(pipeline).wait();

-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
     const bool is_test = ctx.Attr<bool>("is_test");
     if (!is_test) {
-      T threshold = exp(-64);
-      for (int i = 0; i < dst_tz[0] * dst_tz[1]; ++i) {
-        output_data[i] =
-            output_data[i] < threshold ? threshold : output_data[i];
-      }
+      T* output_data = output->mutable_data<T>(ctx.GetPlace());
+      int size = std::accumulate(begin(softmax_tz), end(softmax_tz), 1,
+                                 std::multiplies<int>());
+      std::for_each(output_data, &output_data[size], [](T& val) {
+        val = std::max(val, static_cast<T>(exp(-64)));
+      });
     }

     output->set_layout(framework::DataLayout::kMKLDNN);
@@ -139,31 +136,26 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
         "The shape of softmax_grad's input and output must be identical.");

     auto dims = dout->dims();  // input and output share the same shape
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());

-    std::vector<int> dst_tz = paddle::framework::vectorize<int>(flattened_dims);
-    std::vector<int> src_tz(dst_tz);
+    std::vector<int> softmax_tz = paddle::framework::vectorize<int>(dims);

-    // Same memory descriptor to be used for input and output
-    memory::dims softmax_tz = {src_tz[0], src_tz[1]};
-
-    // TODO(jczaja): Add layouts support when there is a need to do so
-    // Two dimensional softmax does support NC format
-    // Normalization is made after innermost dimension eg. C out of NC
-    SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc,
-                                    MKLDNNMemoryFormat::nc, dev_ctx,
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, output->format(),
+                                    dout->format(), axis, dev_ctx,
                                     ctx.GetPlace(), ctx.op().Input("Out"));

     auto dst_memory_p = handler.AcquireDstMemory(output);
     auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
     auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);

-    // Get primitve from device context
     auto softmax_bwd_p = handler.AcquireBackwardPrimitive(
         *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);

     std::vector<primitive> pipeline{*softmax_bwd_p};
     stream(stream::kind::eager).submit(pipeline).wait();
+
+    dx->set_layout(framework::DataLayout::kMKLDNN);
+    dx->set_format(dout->format());
   }
 };

 }  // namespace operators
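Note on the relocated thresholding above: in training mode (`!is_test`) the kernel clamps each softmax output to at least exp(-64), presumably so a downstream log (for example in cross-entropy) never receives an exact zero, and the element count is now derived from the full `softmax_tz` shape instead of a hard-coded 2-D size. A standalone sketch of the same computation; `ClampSoftmaxOutput` is an illustrative name, not Paddle API:

```cpp
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>

// Mirrors the !is_test branch above: `dims` plays the role of
// softmax_tz and `data` the role of output_data. Every element is
// raised to at least exp(-64) (~1.6e-28) so log() stays finite.
template <typename T>
void ClampSoftmaxOutput(const std::vector<int>& dims, T* data) {
  const int size = std::accumulate(dims.begin(), dims.end(), 1,
                                   std::multiplies<int>());
  std::for_each(data, data + size, [](T& val) {
    val = std::max(val, static_cast<T>(std::exp(-64.0)));
  });
}
```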
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 17b944044654223450cd7baba04d0d5b8bf7c0f4..9d73a19197c29fae29728cd6ab770bc0cc7a3ab1 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -47,10 +47,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
                    "R is the rank of Input(X).");

     auto use_cudnn = ctx->Attrs().Get<bool>("use_cudnn");
-    auto use_mkldnn = ctx->Attrs().Get<bool>("use_mkldnn");
     if (axis != rank_x - 1 && axis != -1) {
       PADDLE_ENFORCE(!use_cudnn, "CUDNN kernel only support axis as -1.");
-      PADDLE_ENFORCE(!use_mkldnn, "MKLDNN kernel only support axis as -1.");
     }

     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py
index 748b77f2bf48f450426d3ea918138a7db8df78f0..11d79d35944005009aadeddd19dc35bf561de6cd 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_softmax_mkldnn_op.py
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 from paddle.fluid.tests.unittests.op_test import OpTest
 import paddle.fluid.core as core
-from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp, stable_softmax
+from paddle.fluid.tests.unittests.test_softmax_op import *
 from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd


@@ -27,9 +27,29 @@ class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
         self.use_mkldnn = True


-class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
-    def get_x_shape(self):
-        return [2, 3, 4, 5]
+class TestSoftmaxMKLDNNOp2(TestSoftmaxOp2):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp3(TestSoftmaxOp3):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp4(TestSoftmaxOp4):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp5(TestSoftmaxOp5):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp6(TestSoftmaxOp6):
+    def init_kernel_type(self):
+        self.use_mkldnn = True


 # Check if primitives already exist in backward
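Note: with the MKL-DNN restriction removed from SoftmaxOp above, `axis` may be any value in [-R, R), and both MKL-DNN kernels normalize it through `CanonicalAxis` from softmax_op.h before handing it to the primitive descriptor. A sketch of the assumed behavior; the body below is my reading of that helper, not copied from the source:

```cpp
// Assumed semantics of CanonicalAxis: a negative axis counts from the
// last dimension, so for a rank-4 input CanonicalAxis(-1, 4) == 3 and
// CanonicalAxis(2, 4) == 2.
static inline int CanonicalAxis(const int axis, const int rank) {
  return axis < 0 ? axis + rank : axis;
}
```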
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index 8b071260285a1ff50e3c49ec0ac84f388fff97bf..ea14648e2015da5ac715cb4c74f51b097cf5a3d0 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -103,7 +103,7 @@ class TestSoftmaxOp5(TestSoftmaxOp):
         return 2


-class TestSoftmaxOp5(TestSoftmaxOp):
+class TestSoftmaxOp6(TestSoftmaxOp):
     def get_x_shape(self):
         return [2, 3, 4, 5]
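For context on what the widened test matrix checks: TestSoftmaxOp2 through TestSoftmaxOp6 vary the input shape and the softmax axis, comparing against numpy's stable_softmax applied along that axis. A minimal C++ reference of the same math for a dense row-major tensor, illustrative only; the real check lives in the Python tests:

```cpp
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>

// Reference softmax along `axis` of a row-major tensor with shape
// `dims`. Elements along `axis` are `inner` apart, where `inner` is
// the product of the dimensions after `axis`.
template <typename T>
void SoftmaxAlongAxis(const std::vector<int>& dims, int axis,
                      const T* in, T* out) {
  const int n = dims[axis];
  const int inner = std::accumulate(dims.begin() + axis + 1, dims.end(), 1,
                                    std::multiplies<int>());
  const int outer = std::accumulate(dims.begin(), dims.begin() + axis, 1,
                                    std::multiplies<int>());
  for (int o = 0; o < outer; ++o) {
    for (int i = 0; i < inner; ++i) {
      const int base = o * n * inner + i;
      // Shift by the row maximum first (the "stable" in stable_softmax).
      T max_val = in[base];
      for (int k = 1; k < n; ++k)
        max_val = std::max(max_val, in[base + k * inner]);
      T sum = static_cast<T>(0);
      for (int k = 0; k < n; ++k) {
        out[base + k * inner] = std::exp(in[base + k * inner] - max_val);
        sum += out[base + k * inner];
      }
      for (int k = 0; k < n; ++k) out[base + k * inner] /= sum;
    }
  }
}
```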