未验证 提交 d8b69124 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Make CreateInferMeta more robust (#42871)

上级 723c4ae7
......@@ -24,6 +24,20 @@ void AssignValueInferMeta(const std::vector<int>& shape,
}
void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) {
if (!shape.FromTensor()) {
const auto& data = shape.GetData();
for (size_t i = 0; i < data.size(); ++i) {
PADDLE_ENFORCE_GE(
data[i],
0,
phi::errors::InvalidArgument(
"Each value of attribute 'shape' is expected to be no less "
"than 0. But recieved: shape[%u] = %d; shape = [%s].",
i,
data[i],
phi::make_ddim(data)));
}
}
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
}
......
......@@ -17,6 +17,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.compat as cpt
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
......@@ -96,9 +97,19 @@ class ApiZerosError(unittest.TestCase):
self.assertRaises(TypeError, test_error2)
def test_shape_errors(self):
with fluid.dygraph.guard():
try:
shape = [-1, 5]
out = paddle.zeros(shape)
except Exception as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("expected to be no less than 0") > 0
def test_eager(self):
with _test_eager_guard():
self.test_errors()
self.test_shape_errors()
if (__name__ == '__main__'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册