diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 9f97556e200403e7dca76bad787551cfc44432b9..629419c2a4073041feb7ce7da4cac68c56771249 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -92,7 +92,8 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { // sometimes users provide PyLong or numpy.int64 but attr is float if (PyFloat_Check(*obj) || PyLong_Check(*obj) || PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT - PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT + (PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) && // NOLINT + (((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT return true; } if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT