未验证 提交 d8b69124 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Make CreateInferMeta more robust (#42871)

上级 723c4ae7
...@@ -24,6 +24,20 @@ void AssignValueInferMeta(const std::vector<int>& shape, ...@@ -24,6 +24,20 @@ void AssignValueInferMeta(const std::vector<int>& shape,
} }
void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) {
if (!shape.FromTensor()) {
const auto& data = shape.GetData();
for (size_t i = 0; i < data.size(); ++i) {
PADDLE_ENFORCE_GE(
data[i],
0,
phi::errors::InvalidArgument(
"Each value of attribute 'shape' is expected to be no less "
"than 0. But recieved: shape[%u] = %d; shape = [%s].",
i,
data[i],
phi::make_ddim(data)));
}
}
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
} }
......
...@@ -17,6 +17,7 @@ import unittest ...@@ -17,6 +17,7 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle import paddle
import paddle.compat as cpt
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -96,9 +97,19 @@ class ApiZerosError(unittest.TestCase): ...@@ -96,9 +97,19 @@ class ApiZerosError(unittest.TestCase):
self.assertRaises(TypeError, test_error2) self.assertRaises(TypeError, test_error2)
def test_shape_errors(self):
with fluid.dygraph.guard():
try:
shape = [-1, 5]
out = paddle.zeros(shape)
except Exception as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("expected to be no less than 0") > 0
def test_eager(self): def test_eager(self):
with _test_eager_guard(): with _test_eager_guard():
self.test_errors() self.test_errors()
self.test_shape_errors()
if (__name__ == '__main__'): if (__name__ == '__main__'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册