未验证 提交 0093aaa6 编写于 作者: J jiangcheng 提交者: GitHub

【Zero-Dim】Flatten support 0d tensor (#49361)

* flatten op support 0D-tensor

* add test in zero dim py

* fix shape should be list

* short code for ci-coverage

* add backward test

* simple code for ci coverage

* add axis check

* add 0D-tensor test in test_flatten_contiguous_range_op.py

* add axis error test for Coverage CI

* add more test for CI-Coverage

* add more test for CI-Coverage
上级 215c7ae7
......@@ -41,11 +41,13 @@ class FlattenOp : public framework::OperatorWithKernel {
0,
platform::errors::InvalidArgument(
"The axis should be greater than or equal to 0."));
PADDLE_ENFORCE_LE(
axis,
in_dims.size(),
platform::errors::InvalidArgument(
"The axis should be less than or equal to input tensor's rank."));
if (in_dims.size() > 0) {
PADDLE_ENFORCE_LE(
axis,
in_dims.size(),
platform::errors::InvalidArgument(
"The axis should be less than or equal to input tensor's rank."));
}
const auto &out_dims = GetOutputShape(axis, in_dims);
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
......@@ -58,6 +60,10 @@ class FlattenOp : public framework::OperatorWithKernel {
static std::vector<int32_t> GetOutputShape(const int axis,
const framework::DDim &in_dims) {
if (in_dims.size() == 0) {
return {1};
}
int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) {
......
......@@ -49,6 +49,10 @@ class FlattenKernel : public framework::OpKernel<T> {
static std::vector<int32_t> GetOutputShape(const int axis,
const framework::DDim &in_dims) {
if (in_dims.size() == 0) {
return {1};
}
int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) {
......
......@@ -1097,21 +1097,40 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* xshape) {
auto x_dims = x.dims();
int in_dims_size = x_dims.size();
if (in_dims_size == 0) {
PADDLE_ENFORCE_EQ(
start_axis == 0 || start_axis == -1,
true,
phi::errors::InvalidArgument("The start_axis should be 0 or -1 when "
"the input tensor is a 0D-Tensor"));
PADDLE_ENFORCE_EQ(
stop_axis == 0 || stop_axis == -1,
true,
phi::errors::InvalidArgument("The stop_axis should be 0 or -1 when the "
"input tensor is a 0D-Tensor"));
// this can ensure out shape {1}
start_axis = 0;
stop_axis = -1;
}
if (start_axis < 0) {
start_axis = start_axis + in_dims_size;
}
if (stop_axis < 0) {
stop_axis = stop_axis + in_dims_size;
}
PADDLE_ENFORCE_GE(
stop_axis,
start_axis,
phi::errors::InvalidArgument("The stop_axis should be greater"
"than or equal to start_axis."));
if (in_dims_size > 0) {
PADDLE_ENFORCE_GE(
stop_axis,
start_axis,
phi::errors::InvalidArgument("The stop_axis should be greater"
"than or equal to start_axis."));
}
int64_t outer = 1;
std::vector<int32_t> out_shape;
out_shape.reserve(in_dims_size - stop_axis + start_axis);
out_shape.reserve(in_dims_size - stop_axis + start_axis + 1);
for (int i = 0; i < start_axis; ++i) {
out_shape.push_back(x_dims[i]);
......
......@@ -124,6 +124,20 @@ class TestFlattenOp_5(TestFlattenOp):
}
class TestFlattenOp_6(TestFlattenOp):
def init_test_case(self):
self.in_shape = tuple()
self.start_axis = 0
self.stop_axis = -1
self.new_shape = (1,)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis,
}
class TestFlattenOpSixDims(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
......@@ -156,7 +170,7 @@ class TestFlatten2OpError(unittest.TestCase):
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32'
)
out = paddle.flatten(x_var, start_axis=2, stop_axis=1)
out = paddle.flatten(x_var, start_axis=3, stop_axis=1)
self.assertRaises(ValueError, test_ValueError1)
......@@ -176,6 +190,22 @@ class TestFlatten2OpError(unittest.TestCase):
self.assertRaises(ValueError, test_ValueError3)
def test_ValueError4():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32'
)
paddle.flatten(x_var, start_axis=2.0, stop_axis=10)
self.assertRaises(ValueError, test_ValueError4)
def test_ValueError5():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32'
)
paddle.flatten(x_var, start_axis=2, stop_axis=10.0)
self.assertRaises(ValueError, test_ValueError5)
def test_type():
# dtype must be float32, float64, int8, int32, int64, uint8.
x2 = (
......@@ -295,5 +325,27 @@ class TestDygraphInplaceFlattenPython(unittest.TestCase):
paddle.enable_static()
class TestFlatten0DTensorOpError(unittest.TestCase):
def test_errors(self):
image_shape = tuple()
x = np.random.uniform(-1.0, 1.0, []).astype('float32')
def test_ValueError1():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32'
)
out = paddle.flatten(x_var, start_axis=10, stop_axis=0)
self.assertRaises(ValueError, test_ValueError1)
def test_ValueError2():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32'
)
out = paddle.flatten(x_var, start_axis=0, stop_axis=10)
self.assertRaises(ValueError, test_ValueError2)
if __name__ == "__main__":
unittest.main()
......@@ -721,6 +721,19 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.numpy()[3], 2)
self.assertEqual(out.grad.shape, [5])
def test_flatten(self):
x = paddle.rand([])
x.stop_gradient = False
start_axis = 0
stop_axis = -1
out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis)
out.backward()
self.assertEqual(out.shape, [1])
self.assertEqual(x.grad.shape, [])
def test_scale(self):
x = paddle.rand([])
x.stop_gradient = False
......@@ -1113,6 +1126,22 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, (5,))
self.assertEqual(res[0][3], 2)
@prog_scope()
def test_flatten(self):
x = paddle.full([], 1, 'float32')
x.stop_gradient = False
start_axis = 0
stop_axis = -1
out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={}, fetch_list=[out])
self.assertEqual(res[0].shape, (1,))
@prog_scope()
def test_scale(self):
x = paddle.rand([])
......
......@@ -521,6 +521,19 @@ class TestSundryAPI(unittest.TestCase):
for i in range(3):
self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
def test_flatten(self):
x = paddle.full([], 1, 'float32')
x.stop_gradient = False
start_axis = 0
stop_axis = -1
out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis)
out.backward()
self.assertEqual(out.shape, [1])
self.assertEqual(x.grad.shape, [])
def test_scale(self):
x = paddle.rand([])
x.stop_gradient = False
......
......@@ -1550,28 +1550,38 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
)
x_dim = len(x.shape)
if (
not (isinstance(start_axis, int))
or (start_axis > x_dim - 1)
or start_axis < -x_dim
):
raise ValueError(
"The start_axis should be a int, and in range [-rank(x), rank(x))"
)
if (
not (isinstance(stop_axis, int))
or (stop_axis > x_dim - 1)
or stop_axis < -x_dim
):
raise ValueError(
"The stop_axis should be a int, and in range [-rank(x), rank(x))"
)
if start_axis < 0:
start_axis = start_axis + x_dim
if stop_axis < 0:
stop_axis = stop_axis + x_dim
if start_axis > stop_axis:
raise ValueError("The stop_axis should be larger than stat_axis")
if x_dim == 0:
if not (isinstance(start_axis, int)) or start_axis not in [0, -1]:
raise ValueError(
"The start_axis should be int, and should be 0 or -1 when the input tensor is a 0D-Tensor"
)
if not (isinstance(stop_axis, int)) or stop_axis not in [0, -1]:
raise ValueError(
"The stop_axis should be int, and should be 0 or -1 when the input tensor is a 0D-Tensor"
)
else:
if (
not (isinstance(start_axis, int))
or (start_axis > x_dim - 1)
or start_axis < -x_dim
):
raise ValueError(
"The start_axis should be a int, and in range [-rank(x), rank(x))"
)
if (
not (isinstance(stop_axis, int))
or (stop_axis > x_dim - 1)
or stop_axis < -x_dim
):
raise ValueError(
"The stop_axis should be a int, and in range [-rank(x), rank(x))"
)
if start_axis < 0:
start_axis = start_axis + x_dim
if stop_axis < 0:
stop_axis = stop_axis + x_dim
if start_axis > stop_axis:
raise ValueError("The stop_axis should be larger than stat_axis")
if in_dygraph_mode():
return _C_ops.flatten(x, start_axis, stop_axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册