From d8b691242d02b4117eb4b06985cd0553946bac12 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 20 May 2022 16:10:07 +0800 Subject: [PATCH] [Eager] Make CreateInferMeta more robust (#42871) --- paddle/phi/infermeta/nullary.cc | 14 ++++++++++++++ .../paddle/fluid/tests/unittests/test_zeros_op.py | 11 +++++++++++ 2 files changed, 25 insertions(+) diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index c3ded621718..069359bae92 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -24,6 +24,20 @@ void AssignValueInferMeta(const std::vector& shape, } void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { + if (!shape.FromTensor()) { + const auto& data = shape.GetData(); + for (size_t i = 0; i < data.size(); ++i) { + PADDLE_ENFORCE_GE( + data[i], + 0, + phi::errors::InvalidArgument( + "Each value of attribute 'shape' is expected to be no less " + "than 0. But recieved: shape[%u] = %d; shape = [%s].", + i, + data[i], + phi::make_ddim(data))); + } + } CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); } diff --git a/python/paddle/fluid/tests/unittests/test_zeros_op.py b/python/paddle/fluid/tests/unittests/test_zeros_op.py index 449f95aac29..01d7107cfae 100644 --- a/python/paddle/fluid/tests/unittests/test_zeros_op.py +++ b/python/paddle/fluid/tests/unittests/test_zeros_op.py @@ -17,6 +17,7 @@ import unittest import numpy as np from op_test import OpTest import paddle +import paddle.compat as cpt import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid @@ -96,9 +97,19 @@ class ApiZerosError(unittest.TestCase): self.assertRaises(TypeError, test_error2) + def test_shape_errors(self): + with fluid.dygraph.guard(): + try: + shape = [-1, 5] + out = paddle.zeros(shape) + except Exception as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("expected to be no less than 0") > 0 + def test_eager(self): with _test_eager_guard(): self.test_errors() + self.test_shape_errors() if (__name__ == '__main__'): -- GitLab