未验证 提交 765fbb59 编写于 作者: Z zyfncg 提交者: GitHub

[cherry-pick] Fix bug of building InferMetaContext (#42211) (#42399)

* fix bug of building InferMetaContext (#42211)

* add unitest
上级 df39d157
...@@ -558,10 +558,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -558,10 +558,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} }
if (num_ele <= 0) { if (num_ele <= 0) {
PADDLE_THROW(platform::errors::Unimplemented( num_ele = tensor_dims.size();
"Invalid number for construct phi::IntArray, expected "
"number > 0, but actually is %d. ",
num_ele));
} }
} else { } else {
......
...@@ -232,28 +232,33 @@ class TestEmptyAPI(unittest.TestCase): ...@@ -232,28 +232,33 @@ class TestEmptyAPI(unittest.TestCase):
name="shape_tensor_int32", shape=[2], dtype="int32") name="shape_tensor_int32", shape=[2], dtype="int32")
shape_tensor_int64 = fluid.data( shape_tensor_int64 = fluid.data(
name="shape_tensor_int64", shape=[2], dtype="int64") name="shape_tensor_int64", shape=[2], dtype="int64")
shape_tensor_unknown = fluid.data(
name="shape_tensor_unknown", shape=[-1], dtype="int64")
out_1 = paddle.empty(shape=[200, 3], dtype=dtype) out_1 = paddle.empty(shape=[200, 3], dtype=dtype)
out_2 = paddle.empty(shape=shape_tensor_int32, dtype=dtype) out_2 = paddle.empty(shape=shape_tensor_int32, dtype=dtype)
out_3 = paddle.empty(shape=shape_tensor_int64, dtype=dtype) out_3 = paddle.empty(shape=shape_tensor_int64, dtype=dtype)
out_4 = paddle.empty(shape=[200, positive_2_int32], dtype=dtype) out_4 = paddle.empty(shape=[200, positive_2_int32], dtype=dtype)
out_5 = paddle.empty(shape=[200, positive_2_int64], dtype=dtype) out_5 = paddle.empty(shape=[200, positive_2_int64], dtype=dtype)
out_6 = paddle.empty(shape=shape_tensor_unknown, dtype=dtype)
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
res_1, res_2, res_3, res_4, res_5 = exe.run( res_1, res_2, res_3, res_4, res_5, res_6 = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed={ feed={
"shape_tensor_int32": np.array([200, 3]).astype("int32"), "shape_tensor_int32": np.array([200, 3]).astype("int32"),
"shape_tensor_int64": np.array([200, 3]).astype("int64"), "shape_tensor_int64": np.array([200, 3]).astype("int64"),
"shape_tensor_unknown": np.array([200, 3]).astype("int64"),
}, },
fetch_list=[out_1, out_2, out_3, out_4, out_5]) fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
self.__check_out__(res_1, dtype) self.__check_out__(res_1, dtype)
self.__check_out__(res_2, dtype) self.__check_out__(res_2, dtype)
self.__check_out__(res_3, dtype) self.__check_out__(res_3, dtype)
self.__check_out__(res_4, dtype) self.__check_out__(res_4, dtype)
self.__check_out__(res_5, dtype) self.__check_out__(res_5, dtype)
self.__check_out__(res_6, dtype)
class TestEmptyError(unittest.TestCase): class TestEmptyError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册