diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 089698517a0c94e7f4ba6b50d90045ad03b3ddc8..1abfe5fbd0c5a75f08fc7dbfe1b5926565c4be14 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -558,10 +558,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } if (num_ele <= 0) { - PADDLE_THROW(platform::errors::Unimplemented( - "Invalid number for construct phi::IntArray, expected " - "number > 0, but actually is %d. ", - num_ele)); + num_ele = tensor_dims.size(); } } else { diff --git a/python/paddle/fluid/tests/unittests/test_empty_op.py b/python/paddle/fluid/tests/unittests/test_empty_op.py index b8ff66a910ece4584be53baf1d194161fec0965a..371c59a1b8cce36cc4ecec73dd938250b1a0f3b4 100644 --- a/python/paddle/fluid/tests/unittests/test_empty_op.py +++ b/python/paddle/fluid/tests/unittests/test_empty_op.py @@ -232,28 +232,33 @@ class TestEmptyAPI(unittest.TestCase): name="shape_tensor_int32", shape=[2], dtype="int32") shape_tensor_int64 = fluid.data( name="shape_tensor_int64", shape=[2], dtype="int64") + shape_tensor_unknown = fluid.data( + name="shape_tensor_unknown", shape=[-1], dtype="int64") out_1 = paddle.empty(shape=[200, 3], dtype=dtype) out_2 = paddle.empty(shape=shape_tensor_int32, dtype=dtype) out_3 = paddle.empty(shape=shape_tensor_int64, dtype=dtype) out_4 = paddle.empty(shape=[200, positive_2_int32], dtype=dtype) out_5 = paddle.empty(shape=[200, positive_2_int64], dtype=dtype) + out_6 = paddle.empty(shape=shape_tensor_unknown, dtype=dtype) place = paddle.CPUPlace() exe = paddle.static.Executor(place) - res_1, res_2, res_3, res_4, res_5 = exe.run( + res_1, res_2, res_3, res_4, res_5, res_6 = exe.run( fluid.default_main_program(), feed={ "shape_tensor_int32": np.array([200, 3]).astype("int32"), "shape_tensor_int64": np.array([200, 3]).astype("int64"), + "shape_tensor_unknown": np.array([200, 3]).astype("int64"), }, - fetch_list=[out_1, out_2, out_3, out_4, out_5]) + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6]) self.__check_out__(res_1, dtype) self.__check_out__(res_2, dtype) self.__check_out__(res_3, dtype) self.__check_out__(res_4, dtype) self.__check_out__(res_5, dtype) + self.__check_out__(res_6, dtype) class TestEmptyError(unittest.TestCase):