diff --git a/paddle/fluid/pybind/op_function.h b/paddle/fluid/pybind/op_function.h index 70b321f658cd2cf1bd43cd6440bf83e1f4dab140..1e20ac958b9bbbc6606b4db0b12b90708526f53a 100644 --- a/paddle/fluid/pybind/op_function.h +++ b/paddle/fluid/pybind/op_function.h @@ -36,9 +36,15 @@ namespace pybind { static inline std::shared_ptr CastPyHandleToVarBase( const std::string& op_type, const std::string& arg_name, int arg_idx, - const py::handle& handle) { + const py::handle& handle, bool dispensable = false) { PyObject* py_obj = handle.ptr(); // get underlying PyObject if (!py_obj || py_obj == Py_None) { + if (!dispensable) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got " + "%s", + op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name)); + } return nullptr; } try { @@ -54,9 +60,15 @@ static inline std::shared_ptr CastPyHandleToVarBase( static inline std::vector> CastPyHandleToVarBaseList(const std::string& op_type, const std::string& arg_name, int arg_idx, - const py::handle& handle) { + const py::handle& handle, bool dispensable = false) { PyObject* py_obj = handle.ptr(); // get underlying PyObject if (!py_obj || py_obj == Py_None) { + if (!dispensable) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got " + "%s", + op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name)); + } return {}; } std::vector> result; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 10914cf0ab7ba2292e59847bafdff1ce23a730e1..0f5ce8415594621fdbfa1fb7a60fa0ab469a18e5 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -166,10 +166,10 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr)"; const char* OUT_VAR_LIST_TYPE = R"(std::vector>)"; const char* CAST_VAR_TEMPLATE = R"( - auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)"; + auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)"; const char* CAST_VAR_LIST_TEMPLATE = R"( - auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)"; + auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)"; const char* ARG_TEMPLATE = R"(const %s& %s)"; @@ -263,9 +263,10 @@ GenerateOpFunctions(const std::string& module_name) { input_args_num++; const auto in_cast_type = input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; + auto dispensable = input.dispensable() ? "true" : "false"; ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name, - arg_idx++, TempName(in_name)); + arg_idx++, TempName(in_name), dispensable); if (input.dispensable()) { const auto in_template = input.duplicable()