未验证 提交 0093aaa6 编写于 作者: J jiangcheng 提交者: GitHub

【Zero-Dim】Flatten support 0d tensor (#49361)

* flatten op support 0D-tensor

* add test in zero dim py

* fix shape should be list

* short code for ci-coverage

* add backward test

* simple code for ci coverage

* add axis check

* add 0D-tensor test in test_flatten_contiguous_range_op.py

* add axis error test for Coverage CI

* add more test for CI-Coverage

* add more test for CI-Coverage
上级 215c7ae7
...@@ -41,11 +41,13 @@ class FlattenOp : public framework::OperatorWithKernel { ...@@ -41,11 +41,13 @@ class FlattenOp : public framework::OperatorWithKernel {
0, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The axis should be greater than or equal to 0.")); "The axis should be greater than or equal to 0."));
PADDLE_ENFORCE_LE( if (in_dims.size() > 0) {
axis, PADDLE_ENFORCE_LE(
in_dims.size(), axis,
platform::errors::InvalidArgument( in_dims.size(),
"The axis should be less than or equal to input tensor's rank.")); platform::errors::InvalidArgument(
"The axis should be less than or equal to input tensor's rank."));
}
const auto &out_dims = GetOutputShape(axis, in_dims); const auto &out_dims = GetOutputShape(axis, in_dims);
ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
...@@ -58,6 +60,10 @@ class FlattenOp : public framework::OperatorWithKernel { ...@@ -58,6 +60,10 @@ class FlattenOp : public framework::OperatorWithKernel {
static std::vector<int32_t> GetOutputShape(const int axis, static std::vector<int32_t> GetOutputShape(const int axis,
const framework::DDim &in_dims) { const framework::DDim &in_dims) {
if (in_dims.size() == 0) {
return {1};
}
int64_t outer = 1, inner = 1; int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) { for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) { if (i < axis) {
......
...@@ -49,6 +49,10 @@ class FlattenKernel : public framework::OpKernel<T> { ...@@ -49,6 +49,10 @@ class FlattenKernel : public framework::OpKernel<T> {
static std::vector<int32_t> GetOutputShape(const int axis, static std::vector<int32_t> GetOutputShape(const int axis,
const framework::DDim &in_dims) { const framework::DDim &in_dims) {
if (in_dims.size() == 0) {
return {1};
}
int64_t outer = 1, inner = 1; int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) { for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) { if (i < axis) {
......
...@@ -1097,21 +1097,40 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x, ...@@ -1097,21 +1097,40 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* xshape) { MetaTensor* xshape) {
auto x_dims = x.dims(); auto x_dims = x.dims();
int in_dims_size = x_dims.size(); int in_dims_size = x_dims.size();
if (in_dims_size == 0) {
PADDLE_ENFORCE_EQ(
start_axis == 0 || start_axis == -1,
true,
phi::errors::InvalidArgument("The start_axis should be 0 or -1 when "
"the input tensor is a 0D-Tensor"));
PADDLE_ENFORCE_EQ(
stop_axis == 0 || stop_axis == -1,
true,
phi::errors::InvalidArgument("The stop_axis should be 0 or -1 when the "
"input tensor is a 0D-Tensor"));
// this can ensure out shape {1}
start_axis = 0;
stop_axis = -1;
}
if (start_axis < 0) { if (start_axis < 0) {
start_axis = start_axis + in_dims_size; start_axis = start_axis + in_dims_size;
} }
if (stop_axis < 0) { if (stop_axis < 0) {
stop_axis = stop_axis + in_dims_size; stop_axis = stop_axis + in_dims_size;
} }
PADDLE_ENFORCE_GE( if (in_dims_size > 0) {
stop_axis, PADDLE_ENFORCE_GE(
start_axis, stop_axis,
phi::errors::InvalidArgument("The stop_axis should be greater" start_axis,
"than or equal to start_axis.")); phi::errors::InvalidArgument("The stop_axis should be greater"
"than or equal to start_axis."));
}
int64_t outer = 1; int64_t outer = 1;
std::vector<int32_t> out_shape; std::vector<int32_t> out_shape;
out_shape.reserve(in_dims_size - stop_axis + start_axis); out_shape.reserve(in_dims_size - stop_axis + start_axis + 1);
for (int i = 0; i < start_axis; ++i) { for (int i = 0; i < start_axis; ++i) {
out_shape.push_back(x_dims[i]); out_shape.push_back(x_dims[i]);
......
...@@ -124,6 +124,20 @@ class TestFlattenOp_5(TestFlattenOp): ...@@ -124,6 +124,20 @@ class TestFlattenOp_5(TestFlattenOp):
} }
class TestFlattenOp_6(TestFlattenOp):
    # 0D-tensor case: flattening a scalar input with the full axis
    # range must produce a one-element 1D tensor of shape (1,).
    def init_test_case(self):
        self.in_shape = ()
        self.start_axis = 0
        self.stop_axis = -1
        self.new_shape = (1,)

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis,
        }
class TestFlattenOpSixDims(TestFlattenOp): class TestFlattenOpSixDims(TestFlattenOp):
def init_test_case(self): def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4) self.in_shape = (3, 2, 3, 2, 4, 4)
...@@ -156,7 +170,7 @@ class TestFlatten2OpError(unittest.TestCase): ...@@ -156,7 +170,7 @@ class TestFlatten2OpError(unittest.TestCase):
x_var = paddle.static.data( x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32' name="x", shape=image_shape, dtype='float32'
) )
out = paddle.flatten(x_var, start_axis=2, stop_axis=1) out = paddle.flatten(x_var, start_axis=3, stop_axis=1)
self.assertRaises(ValueError, test_ValueError1) self.assertRaises(ValueError, test_ValueError1)
...@@ -176,6 +190,22 @@ class TestFlatten2OpError(unittest.TestCase): ...@@ -176,6 +190,22 @@ class TestFlatten2OpError(unittest.TestCase):
self.assertRaises(ValueError, test_ValueError3) self.assertRaises(ValueError, test_ValueError3)
def test_ValueError4():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32'
)
paddle.flatten(x_var, start_axis=2.0, stop_axis=10)
self.assertRaises(ValueError, test_ValueError4)
def test_ValueError5():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32'
)
paddle.flatten(x_var, start_axis=2, stop_axis=10.0)
self.assertRaises(ValueError, test_ValueError5)
def test_type(): def test_type():
# dtype must be float32, float64, int8, int32, int64, uint8. # dtype must be float32, float64, int8, int32, int64, uint8.
x2 = ( x2 = (
...@@ -295,5 +325,27 @@ class TestDygraphInplaceFlattenPython(unittest.TestCase): ...@@ -295,5 +325,27 @@ class TestDygraphInplaceFlattenPython(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
class TestFlatten0DTensorOpError(unittest.TestCase):
    """Static-graph error checks for ``paddle.flatten`` on a 0D (scalar) input.

    For a 0D tensor only ``start_axis``/``stop_axis`` in {0, -1} are legal,
    so any other axis value must raise ``ValueError``.
    """

    def test_errors(self):
        # Empty shape declares a 0D tensor in the static graph.
        # NOTE: the original also built an unused numpy scalar here; it was
        # dead code and has been removed.
        image_shape = ()

        def test_ValueError1():
            # start_axis=10 is outside {0, -1} for a 0D input.
            x_var = paddle.static.data(
                name="x", shape=image_shape, dtype='float32'
            )
            paddle.flatten(x_var, start_axis=10, stop_axis=0)

        self.assertRaises(ValueError, test_ValueError1)

        def test_ValueError2():
            # stop_axis=10 is outside {0, -1} for a 0D input.
            x_var = paddle.static.data(
                name="x", shape=image_shape, dtype='float32'
            )
            paddle.flatten(x_var, start_axis=0, stop_axis=10)

        self.assertRaises(ValueError, test_ValueError2)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -721,6 +721,19 @@ class TestSundryAPI(unittest.TestCase): ...@@ -721,6 +721,19 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.numpy()[3], 2) self.assertEqual(out.numpy()[3], 2)
self.assertEqual(out.grad.shape, [5]) self.assertEqual(out.grad.shape, [5])
def test_flatten(self):
    # Flattening a 0D tensor over the full axis range yields a
    # one-element 1D tensor, while the gradient w.r.t. the scalar
    # input keeps the 0D shape.
    x = paddle.rand([])
    x.stop_gradient = False

    out = paddle.flatten(x, start_axis=0, stop_axis=-1)
    out.backward()

    self.assertEqual(out.shape, [1])
    self.assertEqual(x.grad.shape, [])
def test_scale(self): def test_scale(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
...@@ -1113,6 +1126,22 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1113,6 +1126,22 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, (5,)) self.assertEqual(res[0].shape, (5,))
self.assertEqual(res[0][3], 2) self.assertEqual(res[0][3], 2)
@prog_scope()
def test_flatten(self):
    # Static-graph check: flatten of a 0D tensor must come back from
    # the executor with shape (1,).
    x = paddle.full([], 1, 'float32')
    x.stop_gradient = False

    out = paddle.flatten(x, start_axis=0, stop_axis=-1)
    paddle.static.append_backward(out)

    main_prog = paddle.static.default_main_program()
    res = self.exe.run(main_prog, feed={}, fetch_list=[out])
    self.assertEqual(res[0].shape, (1,))
@prog_scope() @prog_scope()
def test_scale(self): def test_scale(self):
x = paddle.rand([]) x = paddle.rand([])
......
...@@ -521,6 +521,19 @@ class TestSundryAPI(unittest.TestCase): ...@@ -521,6 +521,19 @@ class TestSundryAPI(unittest.TestCase):
for i in range(3): for i in range(3):
self.assertEqual(out.numpy()[1][i], updates.numpy()[i]) self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
def test_flatten(self):
    # Flatten of a 0D tensor: the output is a 1D tensor of shape [1]
    # and the gradient of the scalar input stays 0D.
    x = paddle.full([], 1, 'float32')
    x.stop_gradient = False

    out = paddle.flatten(x, start_axis=0, stop_axis=-1)
    out.backward()

    self.assertEqual(out.shape, [1])
    self.assertEqual(x.grad.shape, [])
def test_scale(self): def test_scale(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
......
...@@ -1550,28 +1550,38 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): ...@@ -1550,28 +1550,38 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
) )
x_dim = len(x.shape) x_dim = len(x.shape)
if ( if x_dim == 0:
not (isinstance(start_axis, int)) if not (isinstance(start_axis, int)) or start_axis not in [0, -1]:
or (start_axis > x_dim - 1) raise ValueError(
or start_axis < -x_dim "The start_axis should be int, and should be 0 or -1 when the input tensor is a 0D-Tensor"
): )
raise ValueError( if not (isinstance(stop_axis, int)) or stop_axis not in [0, -1]:
"The start_axis should be a int, and in range [-rank(x), rank(x))" raise ValueError(
) "The stop_axis should be int, and should be 0 or -1 when the input tensor is a 0D-Tensor"
if ( )
not (isinstance(stop_axis, int)) else:
or (stop_axis > x_dim - 1) if (
or stop_axis < -x_dim not (isinstance(start_axis, int))
): or (start_axis > x_dim - 1)
raise ValueError( or start_axis < -x_dim
"The stop_axis should be a int, and in range [-rank(x), rank(x))" ):
) raise ValueError(
if start_axis < 0: "The start_axis should be a int, and in range [-rank(x), rank(x))"
start_axis = start_axis + x_dim )
if stop_axis < 0: if (
stop_axis = stop_axis + x_dim not (isinstance(stop_axis, int))
if start_axis > stop_axis: or (stop_axis > x_dim - 1)
raise ValueError("The stop_axis should be larger than stat_axis") or stop_axis < -x_dim
):
raise ValueError(
"The stop_axis should be a int, and in range [-rank(x), rank(x))"
)
if start_axis < 0:
start_axis = start_axis + x_dim
if stop_axis < 0:
stop_axis = stop_axis + x_dim
if start_axis > stop_axis:
raise ValueError("The stop_axis should be larger than stat_axis")
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.flatten(x, start_axis, stop_axis) return _C_ops.flatten(x, start_axis, stop_axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册