未验证 提交 b23801a2 编写于 作者: C Chen Weihang 提交者: GitHub

polish tensor set error messag, test=develop (#25113)

上级 542a226c
...@@ -246,12 +246,13 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, ...@@ -246,12 +246,13 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
} else if (py::isinstance<py::array_t<bool>>(array)) { } else if (py::isinstance<py::array_t<bool>>(array)) {
SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy); SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy);
} else { } else {
// obj may be any type, obj.cast<py::array>() may be failed,
// then the array.dtype will be string of unknown meaning,
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Incompatible data type: tensor.set() supports bool, float16, " "Input object type error or incompatible array data type. "
"float32, " "tensor.set() supports array with bool, float16, float32, "
"float64, " "float64, int8, int16, int32, int64, uint8 or uint16, "
"int8, int16, int32, int64 and uint8, uint16, but got %s!", "please check your input or input array data type."));
array.dtype()));
} }
} }
......
...@@ -345,6 +345,22 @@ class TestTensor(unittest.TestCase): ...@@ -345,6 +345,22 @@ class TestTensor(unittest.TestCase):
self.assertEqual([2, 200, 300], tensor.shape()) self.assertEqual([2, 200, 300], tensor.shape())
self.assertTrue(numpy.array_equal(numpy.array(tensor), list_array)) self.assertTrue(numpy.array_equal(numpy.array(tensor), list_array))
def test_tensor_set_error(self):
scope = core.Scope()
var = scope.var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor()
exception = None
try:
error_array = ["1", "2"]
tensor.set(error_array, place)
except core.EnforceNotMet as ex:
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册