未验证 提交 e4a16f3f 编写于 作者: R RedContritio 提交者: GitHub

add tensor numel check for float (#50415)

上级 23e544d1
......@@ -92,7 +92,8 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
// sometimes users provide PyLong or numpy.int64 but attr is float
if (PyFloat_Check(*obj) || PyLong_Check(*obj) ||
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT
(PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) && // NOLINT
(((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT
return true;
}
if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册