diff --git a/paddle/phi/backends/onednn/onednn_helper.h b/paddle/phi/backends/onednn/onednn_helper.h index 040122b692da64982e5a8d2f6cbca78fe7c4c644..c9511e89a8d54312dd78e86104f27264e9106358 100644 --- a/paddle/phi/backends/onednn/onednn_helper.h +++ b/paddle/phi/backends/onednn/onednn_helper.h @@ -69,6 +69,9 @@ inline OneDNNMemoryFormat OneDNNFormatForSize(size_t dims_size, inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) { switch (tensor_rank) { + case 0: + // use 1D to represent 0D + return dnnl::memory::format_tag::a; case 1: return dnnl::memory::format_tag::a; case 2: diff --git a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc index dafbb75dc07ac573edcddb442b0d230600315ec0..e63faf26009647fd1ba74edae47094ab13aaf3ed 100644 --- a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc @@ -30,7 +30,7 @@ void TransposeGradKernel(const Context& dev_ctx, const auto& onednn_engine = dev_ctx.GetEngine(); - if (axis.size() == 1) { + if (axis.size() == 1 || axis.size() == 0) { Copy(dev_ctx, out_grad, out_grad.place(), false, x_grad); x_grad->set_mem_desc(out_grad.mem_desc()); return; diff --git a/paddle/phi/kernels/onednn/transpose_kernel.cc b/paddle/phi/kernels/onednn/transpose_kernel.cc index 83953262527ae6d17ee4f2836270246ee45a5d08..136fd0a0e573481ca01225abad23601a752530cf 100644 --- a/paddle/phi/kernels/onednn/transpose_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_kernel.cc @@ -112,7 +112,7 @@ void TransposeKernel(const Context& dev_ctx, SetInMemDescWithLogicalLayoutFusesSupport( dev_ctx, const_cast(&x), x.mem_desc()); - if (axis.size() == 1) { + if (axis.size() == 1 || axis.size() == 0) { Copy(dev_ctx, x, x.place(), false, out); out->set_mem_desc(x.mem_desc()); return; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py index e1ef8a7a12f3ff7d4859126678a756371ad866d1..94c75559ecd8670ed4369f341473a01a4c056e6b 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py @@ -87,5 +87,11 @@ class TestCase4(TestTransposeMKLDNN): self.axis = (4, 2, 3, 1, 0, 5) +class TestCase_ZeroDim(TestTransposeMKLDNN): + def initTestCase(self): + self.shape = () + self.axis = () + + if __name__ == '__main__': unittest.main()