diff --git a/paddle/phi/kernels/impl/solve_kernel_impl.h b/paddle/phi/kernels/impl/solve_kernel_impl.h index b0e6b2b6cc02591b1e3674da9ae6318cdddaadb7..d5ecfdff21a998138a779d65c45ace7321940483 100644 --- a/paddle/phi/kernels/impl/solve_kernel_impl.h +++ b/paddle/phi/kernels/impl/solve_kernel_impl.h @@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx, out_tmp.Resize(out->dims()); out_tmp = *out; - phi::SqueezeInferKernel(dev_ctx, out_tmp, {-1}, out); + phi::Squeeze(dev_ctx, out_tmp, {-1}, out); } else { PADDLE_ENFORCE_EQ( x_dim[x_dim_size - 1], diff --git a/paddle/phi/kernels/squeeze_kernel.cc b/paddle/phi/kernels/squeeze_kernel.cc index a95a8cc9a2ff206f926b09c19d4b7db392bfd379..d36e42c8126619b0d7d7716786ff2b6458a559a5 100644 --- a/paddle/phi/kernels/squeeze_kernel.cc +++ b/paddle/phi/kernels/squeeze_kernel.cc @@ -25,11 +25,7 @@ void SqueezeInferKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, DenseTensor* out) { - auto x_dims = x.dims(); - std::vector tmp(axes.GetData().begin(), axes.GetData().end()); - auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true); - out->Resize(out_dims); - + auto out_dims = out->dims(); dev_ctx.template Alloc(out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); out->Resize(out_dims); // copy will reset the dims. diff --git a/paddle/phi/kernels/squeeze_kernel.h b/paddle/phi/kernels/squeeze_kernel.h index 8114969ea7de7d97c5e5e92ec9d4aefddeb6defe..fcd994de7bff40fe513515dff4e93ab1fa8a0853 100644 --- a/paddle/phi/kernels/squeeze_kernel.h +++ b/paddle/phi/kernels/squeeze_kernel.h @@ -17,6 +17,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { @@ -33,4 +34,14 @@ void SqueezeKernel(const Context& dev_ctx, DenseTensor* out, DenseTensor* xshape); +template +void Squeeze(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out) { + MetaTensor meta_out(out); + SqueezeInferMeta(x, axes, &meta_out); + SqueezeInferKernel(dev_ctx, x, axes, out); +} + } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py index b8374da08727aaa49bac1e5264a354c2321bfb15..166864bd5e3df31eb90dd010f109af10a6fbd73f 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py @@ -155,5 +155,37 @@ class TestSqueeze2AxesTensorList(UnittestBase): self.assertEqual(infer_out.shape, (2, 3, 10)) +# test api +class TestSqueezeAPI(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.squeeze = paddle.squeeze + + def test_api(self): + paddle.disable_static() + input_data = np.random.random([3, 2, 1]).astype("float32") + x = paddle.to_tensor(input_data) + out = self.squeeze(x, axis=2) + out.backward() + + self.assertEqual(out.shape, [3, 2]) + + paddle.enable_static() + + def test_error(self): + def test_axes_type(): + x2 = paddle.static.data(name="x2", shape=[2, 1, 25], dtype="int32") + self.squeeze(x2, axis=2.1) + + self.assertRaises(TypeError, test_axes_type) + + +class TestSqueezeInplaceAPI(TestSqueezeAPI): + def executed_api(self): + self.squeeze = paddle.squeeze_ + + if __name__ == "__main__": unittest.main()