From 0093aaa694ce0a44d79d5f90a80820971fc8a0d5 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Fri, 6 Jan 2023 11:41:18 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Zero-Dim=E3=80=91Flatten=20support=200?= =?UTF-8?q?d=20tensor=20(#49361)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * flatten op support 0D-tensor * add test in zero dim py * fix shape should be list * short code for ci-coverage * add backward test * simple code for ci coverage * add axis check * add 0D-tensor test in test_flatten_contiguous_range_op.py * add axis error test for Coverage CI * add more test for CI-Coverage * add more test for CI-Coverage --- paddle/fluid/operators/flatten_op.cc | 16 ++++-- paddle/fluid/operators/flatten_op.h | 4 ++ paddle/phi/infermeta/unary.cc | 31 ++++++++--- .../test_flatten_contiguous_range_op.py | 54 ++++++++++++++++++- .../tests/unittests/test_zero_dim_tensor.py | 29 ++++++++++ .../unittests/xpu/test_zero_dim_tensor_xpu.py | 13 +++++ python/paddle/tensor/manipulation.py | 54 +++++++++++-------- 7 files changed, 167 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 54e35a6f03..6aaa251ead 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -41,11 +41,13 @@ class FlattenOp : public framework::OperatorWithKernel { 0, platform::errors::InvalidArgument( "The axis should be greater than or equal to 0.")); - PADDLE_ENFORCE_LE( - axis, - in_dims.size(), - platform::errors::InvalidArgument( - "The axis should be less than or equal to input tensor's rank.")); + if (in_dims.size() > 0) { + PADDLE_ENFORCE_LE( + axis, + in_dims.size(), + platform::errors::InvalidArgument( + "The axis should be less than or equal to input tensor's rank.")); + } const auto &out_dims = GetOutputShape(axis, in_dims); ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); @@ -58,6 +60,10 @@ class FlattenOp : public framework::OperatorWithKernel { static std::vector GetOutputShape(const int axis, const framework::DDim &in_dims) { + if (in_dims.size() == 0) { + return {1}; + } + int64_t outer = 1, inner = 1; for (int i = 0; i < in_dims.size(); ++i) { if (i < axis) { diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index 7fe55b7f22..513716047e 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -49,6 +49,10 @@ class FlattenKernel : public framework::OpKernel { static std::vector GetOutputShape(const int axis, const framework::DDim &in_dims) { + if (in_dims.size() == 0) { + return {1}; + } + int64_t outer = 1, inner = 1; for (int i = 0; i < in_dims.size(); ++i) { if (i < axis) { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a8ea3ad760..f06fb612ee 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1097,21 +1097,40 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x, MetaTensor* xshape) { auto x_dims = x.dims(); int in_dims_size = x_dims.size(); + + if (in_dims_size == 0) { + PADDLE_ENFORCE_EQ( + start_axis == 0 || start_axis == -1, + true, + phi::errors::InvalidArgument("The start_axis should be 0 or -1 when " + "the input tensor is a 0D-Tensor")); + PADDLE_ENFORCE_EQ( + stop_axis == 0 || stop_axis == -1, + true, + phi::errors::InvalidArgument("The stop_axis should be 0 or -1 when the " + "input tensor is a 0D-Tensor")); + // this can ensure out shape {1} + start_axis = 0; + stop_axis = -1; + } + if (start_axis < 0) { start_axis = start_axis + in_dims_size; } if (stop_axis < 0) { stop_axis = stop_axis + in_dims_size; } - PADDLE_ENFORCE_GE( - stop_axis, - start_axis, - phi::errors::InvalidArgument("The stop_axis should be greater" - "than or equal to start_axis.")); + if (in_dims_size > 0) { + PADDLE_ENFORCE_GE( + stop_axis, + start_axis, + phi::errors::InvalidArgument("The stop_axis should be greater" + "than or equal to start_axis.")); + } int64_t outer = 1; std::vector out_shape; - out_shape.reserve(in_dims_size - stop_axis + start_axis); + out_shape.reserve(in_dims_size - stop_axis + start_axis + 1); for (int i = 0; i < start_axis; ++i) { out_shape.push_back(x_dims[i]); diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index 06cef1d48c..df36af0f51 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -124,6 +124,20 @@ class TestFlattenOp_5(TestFlattenOp): } +class TestFlattenOp_6(TestFlattenOp): + def init_test_case(self): + self.in_shape = tuple() + self.start_axis = 0 + self.stop_axis = -1 + self.new_shape = (1,) + + def init_attrs(self): + self.attrs = { + "start_axis": self.start_axis, + "stop_axis": self.stop_axis, + } + + class TestFlattenOpSixDims(TestFlattenOp): def init_test_case(self): self.in_shape = (3, 2, 3, 2, 4, 4) @@ -156,7 +170,7 @@ class TestFlatten2OpError(unittest.TestCase): x_var = paddle.static.data( name="x", shape=image_shape, dtype='float32' ) - out = paddle.flatten(x_var, start_axis=2, stop_axis=1) + out = paddle.flatten(x_var, start_axis=3, stop_axis=1) self.assertRaises(ValueError, test_ValueError1) @@ -176,6 +190,22 @@ class TestFlatten2OpError(unittest.TestCase): self.assertRaises(ValueError, test_ValueError3) + def test_ValueError4(): + x_var = paddle.static.data( + name="x", shape=image_shape, dtype='float32' + ) + paddle.flatten(x_var, start_axis=2.0, stop_axis=10) + + self.assertRaises(ValueError, test_ValueError4) + + def test_ValueError5(): + x_var = paddle.static.data( + name="x", shape=image_shape, dtype='float32' + ) + paddle.flatten(x_var, start_axis=2, stop_axis=10.0) + + self.assertRaises(ValueError, test_ValueError5) + def test_type(): # dtype must be float32, float64, int8, int32, int64, uint8. x2 = ( @@ -295,5 +325,27 @@ class TestDygraphInplaceFlattenPython(unittest.TestCase): paddle.enable_static() +class TestFlatten0DTensorOpError(unittest.TestCase): + def test_errors(self): + image_shape = tuple() + x = np.random.uniform(-1.0, 1.0, []).astype('float32') + + def test_ValueError1(): + x_var = paddle.static.data( + name="x", shape=image_shape, dtype='float32' + ) + out = paddle.flatten(x_var, start_axis=10, stop_axis=0) + + self.assertRaises(ValueError, test_ValueError1) + + def test_ValueError2(): + x_var = paddle.static.data( + name="x", shape=image_shape, dtype='float32' + ) + out = paddle.flatten(x_var, start_axis=0, stop_axis=10) + + self.assertRaises(ValueError, test_ValueError2) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 546c0c48f9..710480bfd9 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -721,6 +721,19 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out.numpy()[3], 2) self.assertEqual(out.grad.shape, [5]) + def test_flatten(self): + x = paddle.rand([]) + x.stop_gradient = False + + start_axis = 0 + stop_axis = -1 + + out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis) + out.backward() + + self.assertEqual(out.shape, [1]) + self.assertEqual(x.grad.shape, []) + def test_scale(self): x = paddle.rand([]) x.stop_gradient = False @@ -1113,6 +1126,22 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[0].shape, (5,)) self.assertEqual(res[0][3], 2) + @prog_scope() + def test_flatten(self): + x = paddle.full([], 1, 'float32') + x.stop_gradient = False + + start_axis = 0 + stop_axis = -1 + + out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={}, fetch_list=[out]) + + self.assertEqual(res[0].shape, (1,)) + @prog_scope() def test_scale(self): x = paddle.rand([]) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index c0e0de0ac1..95d56bb902 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -521,6 +521,19 @@ class TestSundryAPI(unittest.TestCase): for i in range(3): self.assertEqual(out.numpy()[1][i], updates.numpy()[i]) + def test_flatten(self): + x = paddle.full([], 1, 'float32') + x.stop_gradient = False + + start_axis = 0 + stop_axis = -1 + + out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis) + out.backward() + + self.assertEqual(out.shape, [1]) + self.assertEqual(x.grad.shape, []) + def test_scale(self): x = paddle.rand([]) x.stop_gradient = False diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 70a6d848e1..e47ffd9e7d 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1550,28 +1550,38 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): ) x_dim = len(x.shape) - if ( - not (isinstance(start_axis, int)) - or (start_axis > x_dim - 1) - or start_axis < -x_dim - ): - raise ValueError( - "The start_axis should be a int, and in range [-rank(x), rank(x))" - ) - if ( - not (isinstance(stop_axis, int)) - or (stop_axis > x_dim - 1) - or stop_axis < -x_dim - ): - raise ValueError( - "The stop_axis should be a int, and in range [-rank(x), rank(x))" - ) - if start_axis < 0: - start_axis = start_axis + x_dim - if stop_axis < 0: - stop_axis = stop_axis + x_dim - if start_axis > stop_axis: - raise ValueError("The stop_axis should be larger than stat_axis") + if x_dim == 0: + if not (isinstance(start_axis, int)) or start_axis not in [0, -1]: + raise ValueError( + "The start_axis should be int, and should be 0 or -1 when the input tensor is a 0D-Tensor" + ) + if not (isinstance(stop_axis, int)) or stop_axis not in [0, -1]: + raise ValueError( + "The stop_axis should be int, and should be 0 or -1 when the input tensor is a 0D-Tensor" + ) + else: + if ( + not (isinstance(start_axis, int)) + or (start_axis > x_dim - 1) + or start_axis < -x_dim + ): + raise ValueError( + "The start_axis should be a int, and in range [-rank(x), rank(x))" + ) + if ( + not (isinstance(stop_axis, int)) + or (stop_axis > x_dim - 1) + or stop_axis < -x_dim + ): + raise ValueError( + "The stop_axis should be a int, and in range [-rank(x), rank(x))" + ) + if start_axis < 0: + start_axis = start_axis + x_dim + if stop_axis < 0: + stop_axis = stop_axis + x_dim + if start_axis > stop_axis: + raise ValueError("The stop_axis should be larger than stat_axis") if in_dygraph_mode(): return _C_ops.flatten(x, start_axis, stop_axis) -- GitLab