diff --git a/paddle/phi/kernels/impl/solve_kernel_impl.h b/paddle/phi/kernels/impl/solve_kernel_impl.h index 4120823a9d2e91adc69423b49e809892c4acd402..d5ecfdff21a998138a779d65c45ace7321940483 100644 --- a/paddle/phi/kernels/impl/solve_kernel_impl.h +++ b/paddle/phi/kernels/impl/solve_kernel_impl.h @@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx, out_tmp.Resize(out->dims()); out_tmp = *out; - phi::SqueezeKernel(dev_ctx, out_tmp, {-1}, out); + phi::Squeeze(dev_ctx, out_tmp, {-1}, out); } else { PADDLE_ENFORCE_EQ( x_dim[x_dim_size - 1], diff --git a/paddle/phi/kernels/impl/squeeze_kernel_impl.h b/paddle/phi/kernels/impl/squeeze_kernel_impl.h index cb5eed521cd226d6d390c86a32092bebc1e6905c..a20981d78979bdad848288d46639f29a5e51394e 100644 --- a/paddle/phi/kernels/impl/squeeze_kernel_impl.h +++ b/paddle/phi/kernels/impl/squeeze_kernel_impl.h @@ -23,11 +23,7 @@ void SqueezeKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, DenseTensor* out) { - auto x_dims = x.dims(); - std::vector tmp(axes.GetData().begin(), axes.GetData().end()); - auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true); - out->Resize(out_dims); - + auto out_dims = out->dims(); dev_ctx.template Alloc(out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); out->Resize(out_dims); // copy will reset the dims. diff --git a/paddle/phi/kernels/squeeze_kernel.h b/paddle/phi/kernels/squeeze_kernel.h index 7e5a1b0775ab662b9b88e62fcd23320fd9724d2e..03d708b312089827f8032ef9b40f83bb2abd255b 100644 --- a/paddle/phi/kernels/squeeze_kernel.h +++ b/paddle/phi/kernels/squeeze_kernel.h @@ -17,6 +17,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { @@ -33,4 +34,14 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx, DenseTensor* out, DenseTensor* xshape); +template +void Squeeze(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out) { + MetaTensor meta_out(out); + SqueezeInferMeta(x, axes, &meta_out); + SqueezeKernel(dev_ctx, x, axes, out); +} + } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py index 9d3cd7fecc94334bfd0fdcc863ee8d016626f5b7..5a71efafbdbfc17dfd5315e05fb1b7d786a534dd 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py @@ -28,7 +28,6 @@ paddle.enable_static() # Correct: General. class TestSqueezeOp(OpTest): - def setUp(self): self.op_type = "squeeze2" self.python_api = paddle.squeeze @@ -40,7 +39,7 @@ class TestSqueezeOp(OpTest): self.init_attrs() self.outputs = { "Out": self.inputs["X"].reshape(self.new_shape), - "XShape": np.random.random(self.ori_shape).astype("float64") + "XShape": np.random.random(self.ori_shape).astype("float64"), } def test_check_output(self): @@ -60,7 +59,6 @@ class TestSqueezeOp(OpTest): # Correct: There is mins axis. class TestSqueezeOp1(TestSqueezeOp): - def init_test_case(self): self.ori_shape = (1, 20, 1, 5) self.axes = (0, -2) @@ -69,7 +67,6 @@ class TestSqueezeOp1(TestSqueezeOp): # Correct: No axes input. class TestSqueezeOp2(TestSqueezeOp): - def init_test_case(self): self.ori_shape = (1, 20, 1, 5) self.axes = () @@ -78,7 +75,6 @@ class TestSqueezeOp2(TestSqueezeOp): # Correct: Just part of axes be squeezed. class TestSqueezeOp3(TestSqueezeOp): - def init_test_case(self): self.ori_shape = (6, 1, 5, 1, 4, 1) self.axes = (1, -1) @@ -86,7 +82,6 @@ class TestSqueezeOp3(TestSqueezeOp): class TestSqueeze2AxesTensor(UnittestBase): - def init_info(self): self.shapes = [[2, 3, 4]] self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') @@ -123,7 +118,6 @@ class TestSqueeze2AxesTensor(UnittestBase): class TestSqueeze2AxesTensorList(UnittestBase): - def init_info(self): self.shapes = [[2, 3, 4]] self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') @@ -140,7 +134,7 @@ class TestSqueeze2AxesTensorList(UnittestBase): # axes is a list[Variable] axes = [ paddle.full([1], 0, dtype='int32'), - paddle.full([1], 2, dtype='int32') + paddle.full([1], 2, dtype='int32'), ] out = paddle.squeeze(feat, axes) out2 = paddle.fluid.layers.squeeze(feat, axes) @@ -162,5 +156,37 @@ class TestSqueeze2AxesTensorList(UnittestBase): self.assertEqual(infer_out.shape, (2, 3, 10)) +# test api +class TestSqueezeAPI(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.squeeze = paddle.squeeze + + def test_api(self): + paddle.disable_static() + input_data = np.random.random([3, 2, 1]).astype("float32") + x = paddle.to_tensor(input_data) + out = self.squeeze(x, axis=2) + out.backward() + + self.assertEqual(out.shape, [3, 2]) + + paddle.enable_static() + + def test_error(self): + def test_axes_type(): + x2 = paddle.static.data(name="x2", shape=[2, 1, 25], dtype="int32") + self.squeeze(x2, axis=2.1) + + self.assertRaises(TypeError, test_axes_type) + + +class TestSqueezeInplaceAPI(TestSqueezeAPI): + def executed_api(self): + self.squeeze = paddle.squeeze_ + + if __name__ == "__main__": unittest.main()