From cc511f243975350075ab511ebeb4dd8acba5f76c Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Thu, 9 Mar 2023 11:43:01 +0800 Subject: [PATCH] support ONEDNN 0D for full_kernel (#51265) --- paddle/phi/backends/onednn/onednn_helper.h | 3 +++ paddle/phi/kernels/onednn/transpose_grad_kernel.cc | 2 +- paddle/phi/kernels/onednn/transpose_kernel.cc | 2 +- .../tests/unittests/mkldnn/test_transpose_mkldnn_op.py | 6 ++++++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/paddle/phi/backends/onednn/onednn_helper.h b/paddle/phi/backends/onednn/onednn_helper.h index 040122b692d..c9511e89a8d 100644 --- a/paddle/phi/backends/onednn/onednn_helper.h +++ b/paddle/phi/backends/onednn/onednn_helper.h @@ -69,6 +69,9 @@ inline OneDNNMemoryFormat OneDNNFormatForSize(size_t dims_size, inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) { switch (tensor_rank) { + case 0: + // use 1D to represent 0D + return dnnl::memory::format_tag::a; case 1: return dnnl::memory::format_tag::a; case 2: diff --git a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc index dafbb75dc07..e63faf26009 100644 --- a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc @@ -30,7 +30,7 @@ void TransposeGradKernel(const Context& dev_ctx, const auto& onednn_engine = dev_ctx.GetEngine(); - if (axis.size() == 1) { + if (axis.size() == 1 || axis.size() == 0) { Copy(dev_ctx, out_grad, out_grad.place(), false, x_grad); x_grad->set_mem_desc(out_grad.mem_desc()); return; diff --git a/paddle/phi/kernels/onednn/transpose_kernel.cc b/paddle/phi/kernels/onednn/transpose_kernel.cc index 83953262527..136fd0a0e57 100644 --- a/paddle/phi/kernels/onednn/transpose_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_kernel.cc @@ -112,7 +112,7 @@ void TransposeKernel(const Context& dev_ctx, SetInMemDescWithLogicalLayoutFusesSupport( dev_ctx, const_cast(&x), x.mem_desc()); - if (axis.size() == 1) { + if (axis.size() == 1 || axis.size() == 0) { Copy(dev_ctx, x, x.place(), false, out); out->set_mem_desc(x.mem_desc()); return; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py index e1ef8a7a12f..94c75559ecd 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py @@ -87,5 +87,11 @@ class TestCase4(TestTransposeMKLDNN): self.axis = (4, 2, 3, 1, 0, 5) +class TestCase_ZeroDim(TestTransposeMKLDNN): + def initTestCase(self): + self.shape = () + self.axis = () + + if __name__ == '__main__': unittest.main() -- GitLab