From 11e34ae0066e4a1fe7e2e168d3f82c6a77c82d86 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Thu, 19 Jan 2023 16:23:59 +0800 Subject: [PATCH] Fix paddle.queeze_ bug (#49903) * fix queeze_ bug * fix slove use squeeze_kernel * fix slove use squeeze_kernel * fix slove use squeeze_kernel * add test case --- paddle/phi/kernels/impl/solve_kernel_impl.h | 2 +- paddle/phi/kernels/squeeze_kernel.cc | 6 +--- paddle/phi/kernels/squeeze_kernel.h | 11 +++++++ .../fluid/tests/unittests/test_squeeze2_op.py | 32 +++++++++++++++++++ 4 files changed, 45 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/impl/solve_kernel_impl.h b/paddle/phi/kernels/impl/solve_kernel_impl.h index b0e6b2b6cc..d5ecfdff21 100644 --- a/paddle/phi/kernels/impl/solve_kernel_impl.h +++ b/paddle/phi/kernels/impl/solve_kernel_impl.h @@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx, out_tmp.Resize(out->dims()); out_tmp = *out; - phi::SqueezeInferKernel(dev_ctx, out_tmp, {-1}, out); + phi::Squeeze(dev_ctx, out_tmp, {-1}, out); } else { PADDLE_ENFORCE_EQ( x_dim[x_dim_size - 1], diff --git a/paddle/phi/kernels/squeeze_kernel.cc b/paddle/phi/kernels/squeeze_kernel.cc index a95a8cc9a2..d36e42c812 100644 --- a/paddle/phi/kernels/squeeze_kernel.cc +++ b/paddle/phi/kernels/squeeze_kernel.cc @@ -25,11 +25,7 @@ void SqueezeInferKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, DenseTensor* out) { - auto x_dims = x.dims(); - std::vector tmp(axes.GetData().begin(), axes.GetData().end()); - auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true); - out->Resize(out_dims); - + auto out_dims = out->dims(); dev_ctx.template Alloc(out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); out->Resize(out_dims); // copy will reset the dims. diff --git a/paddle/phi/kernels/squeeze_kernel.h b/paddle/phi/kernels/squeeze_kernel.h index 8114969ea7..fcd994de7b 100644 --- a/paddle/phi/kernels/squeeze_kernel.h +++ b/paddle/phi/kernels/squeeze_kernel.h @@ -17,6 +17,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { @@ -33,4 +34,14 @@ void SqueezeKernel(const Context& dev_ctx, DenseTensor* out, DenseTensor* xshape); +template +void Squeeze(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out) { + MetaTensor meta_out(out); + SqueezeInferMeta(x, axes, &meta_out); + SqueezeInferKernel(dev_ctx, x, axes, out); +} + } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py index b8374da087..166864bd5e 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py @@ -155,5 +155,37 @@ class TestSqueeze2AxesTensorList(UnittestBase): self.assertEqual(infer_out.shape, (2, 3, 10)) +# test api +class TestSqueezeAPI(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.squeeze = paddle.squeeze + + def test_api(self): + paddle.disable_static() + input_data = np.random.random([3, 2, 1]).astype("float32") + x = paddle.to_tensor(input_data) + out = self.squeeze(x, axis=2) + out.backward() + + self.assertEqual(out.shape, [3, 2]) + + paddle.enable_static() + + def test_error(self): + def test_axes_type(): + x2 = paddle.static.data(name="x2", shape=[2, 1, 25], dtype="int32") + self.squeeze(x2, axis=2.1) + + self.assertRaises(TypeError, test_axes_type) + + +class TestSqueezeInplaceAPI(TestSqueezeAPI): + def executed_api(self): + self.squeeze = paddle.squeeze_ + + if __name__ == "__main__": unittest.main() -- GitLab