From 3d09929b1f28b978a5f34dc6139546c4d7def323 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 18 Nov 2020 22:05:41 +0800 Subject: [PATCH] Add check for non-dispensable input (#28666) * Add check for non-dispensable input * fix typo --- paddle/fluid/pybind/op_function.h | 16 ++++++++++++++-- paddle/fluid/pybind/op_function_generator.cc | 7 ++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/op_function.h b/paddle/fluid/pybind/op_function.h index 70b321f658..1e20ac958b 100644 --- a/paddle/fluid/pybind/op_function.h +++ b/paddle/fluid/pybind/op_function.h @@ -36,9 +36,15 @@ namespace pybind { static inline std::shared_ptr CastPyHandleToVarBase( const std::string& op_type, const std::string& arg_name, int arg_idx, - const py::handle& handle) { + const py::handle& handle, bool dispensable = false) { PyObject* py_obj = handle.ptr(); // get underlying PyObject if (!py_obj || py_obj == Py_None) { + if (!dispensable) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got " + "%s", + op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name)); + } return nullptr; } try { @@ -54,9 +60,15 @@ static inline std::shared_ptr CastPyHandleToVarBase( static inline std::vector> CastPyHandleToVarBaseList(const std::string& op_type, const std::string& arg_name, int arg_idx, - const py::handle& handle) { + const py::handle& handle, bool dispensable = false) { PyObject* py_obj = handle.ptr(); // get underlying PyObject if (!py_obj || py_obj == Py_None) { + if (!dispensable) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got " + "%s", + op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name)); + } return {}; } std::vector> result; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 10914cf0ab..0f5ce84155 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -166,10 +166,10 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr)"; const char* OUT_VAR_LIST_TYPE = R"(std::vector>)"; const char* CAST_VAR_TEMPLATE = R"( - auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)"; + auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)"; const char* CAST_VAR_LIST_TEMPLATE = R"( - auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)"; + auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)"; const char* ARG_TEMPLATE = R"(const %s& %s)"; @@ -263,9 +263,10 @@ GenerateOpFunctions(const std::string& module_name) { input_args_num++; const auto in_cast_type = input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; + auto dispensable = input.dispensable() ? "true" : "false"; ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name, - arg_idx++, TempName(in_name)); + arg_idx++, TempName(in_name), dispensable); if (input.dispensable()) { const auto in_template = input.duplicable() -- GitLab