diff --git a/paddle/phi/kernels/onednn/squeeze_kernel.cc b/paddle/phi/kernels/onednn/squeeze_kernel.cc old mode 100644 new mode 100755 index eb7663f8e41b2d9177e20b28e9d467fcb5854465..9f2b9a8a4423c5c64130fedf5f4f2e662295c7bc --- a/paddle/phi/kernels/onednn/squeeze_kernel.cc +++ b/paddle/phi/kernels/onednn/squeeze_kernel.cc @@ -68,9 +68,13 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx, const IntArray& axes, DenseTensor* out, DenseTensor* xshape) { - auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size()); - auto out_dims = out->dims(); - ExecuteSqueeze(dev_ctx, x, x_dims, out_dims, out); + if (xshape == nullptr) { + SqueezeKernel(dev_ctx, x, axes, out); + } else { + auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size()); + auto out_dims = out->dims(); + ExecuteSqueeze(dev_ctx, x, x_dims, out_dims, out); + } } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py old mode 100644 new mode 100755 index 4aeaf625d289eefc8972c61cd53273c90bb89ccd..643eaa6d0d5240611bcf0548031e464ac5ab1b2e --- a/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_squeeze2_mkldnn_op.py @@ -111,6 +111,16 @@ class TestSqueeze2OneDNNOp3(TestSqueeze2OneDNNOp): self.new_shape = (25, 1, 4) +class TestSqueeze2OneDNNOp4(TestSqueeze2OneDNNOp): + def set_outputs(self): + self.outputs = {"Out": self.x.reshape(self.new_shape)} + + def init_test_case(self): + self.ori_shape = (25, 1, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (25, 1, 4) + + class TestSqueezeOneDNNOp3(TestSqueezeOneDNNOp): def init_test_case(self): self.ori_shape = (25, 1, 1, 4, 1)