From b23801a262e2327a82a42f77d774bcf037052fa0 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 19 Jun 2020 10:33:47 +0800 Subject: [PATCH] polish tensor set error messag, test=develop (#25113) --- paddle/fluid/pybind/tensor_py.h | 11 ++++++----- .../paddle/fluid/tests/unittests/test_tensor.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 582fc979d5f..ba79c4b4437 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -246,12 +246,13 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else { + // obj may be any type, obj.cast() may be failed, + // then the array.dtype will be string of unknown meaning, PADDLE_THROW(platform::errors::InvalidArgument( - "Incompatible data type: tensor.set() supports bool, float16, " - "float32, " - "float64, " - "int8, int16, int32, int64 and uint8, uint16, but got %s!", - array.dtype())); + "Input object type error or incompatible array data type. " + "tensor.set() supports array with bool, float16, float32, " + "float64, int8, int16, int32, int64, uint8 or uint16, " + "please check your input or input array data type.")); } } diff --git a/python/paddle/fluid/tests/unittests/test_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor.py index 24be25fda2e..03dffe4e5a2 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor.py @@ -345,6 +345,22 @@ class TestTensor(unittest.TestCase): self.assertEqual([2, 200, 300], tensor.shape()) self.assertTrue(numpy.array_equal(numpy.array(tensor), list_array)) + def test_tensor_set_error(self): + scope = core.Scope() + var = scope.var("test_tensor") + place = core.CPUPlace() + + tensor = var.get_tensor() + + exception = None + try: + error_array = ["1", "2"] + tensor.set(error_array, place) + except core.EnforceNotMet as ex: + exception = ex + + self.assertIsNotNone(exception) + if __name__ == '__main__': unittest.main() -- GitLab