From cef601bf643b82e84634d61a42836122098482c7 Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Tue, 20 Dec 2022 16:12:53 +0800 Subject: [PATCH] Fix ONEDNN squeeze kernel when Xshape does not exist (#49071) * fix onednn squeeze bug * add test * update kernel --- paddle/phi/kernels/onednn/squeeze_kernel.cc | 10 +++++++--- .../tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) mode change 100644 => 100755 paddle/phi/kernels/onednn/squeeze_kernel.cc mode change 100644 => 100755 python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py diff --git a/paddle/phi/kernels/onednn/squeeze_kernel.cc b/paddle/phi/kernels/onednn/squeeze_kernel.cc old mode 100644 new mode 100755 index eb7663f8e4..9f2b9a8a44 --- a/paddle/phi/kernels/onednn/squeeze_kernel.cc +++ b/paddle/phi/kernels/onednn/squeeze_kernel.cc @@ -68,9 +68,13 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx, const IntArray& axes, DenseTensor* out, DenseTensor* xshape) { - auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size()); - auto out_dims = out->dims(); - ExecuteSqueeze(dev_ctx, x, x_dims, out_dims, out); + if (xshape == nullptr) { + SqueezeKernel(dev_ctx, x, axes, out); + } else { + auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size()); + auto out_dims = out->dims(); + ExecuteSqueeze(dev_ctx, x, x_dims, out_dims, out); + } } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py old mode 100644 new mode 100755 index 4aeaf625d2..643eaa6d0d --- a/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py @@ -111,6 +111,16 @@ class TestSqueeze2OneDNNOp3(TestSqueeze2OneDNNOp): self.new_shape = (25, 1, 4) +class TestSqueeze2OneDNNOp4(TestSqueeze2OneDNNOp): + def set_outputs(self): + self.outputs = {"Out": self.x.reshape(self.new_shape)} + + def init_test_case(self): + self.ori_shape = (25, 1, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (25, 1, 4) + + class TestSqueezeOneDNNOp3(TestSqueezeOneDNNOp): def init_test_case(self): self.ori_shape = (25, 1, 1, 4, 1) -- GitLab