diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h
index 5b612677da3554f17ab3ac29ddc241eee5f7c768..ce1ec507307a2721e641ac15425c6a2321e514c7 100644
--- a/paddle/fluid/platform/enforce.h
+++ b/paddle/fluid/platform/enforce.h
@@ -266,7 +266,7 @@ inline std::string GetErrorSumaryString(StrType&& what, const char* file,
   std::ostringstream sout;
   sout << "\n----------------------\nError Message "
           "Summary:\n----------------------\n";
-  sout << string::Sprintf("%s at (%s:%d)", std::forward<StrType>(what), file,
+  sout << string::Sprintf("%s (at %s:%d)", std::forward<StrType>(what), file,
                           line)
        << std::endl;
   return sout.str();
diff --git a/paddle/fluid/pybind/op_function.h b/paddle/fluid/pybind/op_function.h
index 597ead9327e233df785b58437afce8fa75a058c3..b33555759f8ff14b59aec87b1ec8b6a0447025b1 100644
--- a/paddle/fluid/pybind/op_function.h
+++ b/paddle/fluid/pybind/op_function.h
@@ -18,9 +18,11 @@
 #include
 #include
 #include
+
 #include
 #include
 #include
+
 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/variable.h"
@@ -31,6 +33,63 @@ namespace py = pybind11;

 namespace paddle {
 namespace pybind {
+
+static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
+    const std::string& op_type, const std::string& arg_name, int arg_idx,
+    const py::handle& handle) {
+  PyObject* py_obj = handle.ptr();  // get underlying PyObject
+  if (!py_obj || py_obj == Py_None) {
+    return nullptr;
+  }
+  try {
+    return py::cast<std::shared_ptr<imperative::VarBase>>(py::handle(py_obj));
+  } catch (py::cast_error&) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s(): argument '%s' (position %d) must be Tensor, but got "
+        "%s",
+        op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
+  }
+}
+
+static inline std::vector<std::shared_ptr<imperative::VarBase>>
+CastPyHandleToVarBaseList(const std::string& op_type,
+                          const std::string& arg_name, int arg_idx,
+                          const py::handle& handle) {
+  PyObject* py_obj = handle.ptr();  // get underlying PyObject
+  if (!py_obj || py_obj == Py_None) {
+    return {};
+  }
+  std::vector<std::shared_ptr<imperative::VarBase>> result;
+  if (PyList_Check(py_obj) || PyTuple_Check(py_obj)) {
+    auto size = PyTuple_Check(py_obj) ? PyTuple_GET_SIZE(py_obj)
+                                      : PyList_GET_SIZE(py_obj);
+    for (auto i = 0; i < size; ++i) {
+      PyObject* item = PyTuple_Check(py_obj) ? PyTuple_GET_ITEM(py_obj, i)
+                                             : PyList_GET_ITEM(py_obj, i);
+      if (!item || item == Py_None) {
+        result.emplace_back(nullptr);
+        continue;
+      }
+      try {
+        result.emplace_back(
+            py::cast<std::shared_ptr<imperative::VarBase>>(py::handle(item)));
+      } catch (py::cast_error&) {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "%s(): argument '%s' (position %d) must be list of "
+            "Tensors, but "
+            "got %s in list (item %d)",
+            op_type, arg_name, arg_idx, Py_TYPE(item)->tp_name, i));
+      }
+    }
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s(): argument '%s' (position %d) must be list of Tensors, but got "
+        "%s",
+        op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
+  }
+  return result;
+}
+
 static inline void ConstructAttrMapFromPyArgs(framework::AttributeMap* attrs,
                                               const py::args& args) {
   PADDLE_ENFORCE_EQ(
diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index b32f5e8847d30fc785587541ccdc74d99d2b025c..93ba9feedf95b0e342cda8a40f83ba9ca471858f 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -116,8 +116,19 @@ const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
 const char* ARG_OUT_NUM = R"(%sNum)";
 const char* ARG_OUT_NUM_TYPE = R"(size_t )";

-const char* VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
-const char* VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
+const char* IN_VAR_TYPE = R"(py::handle)";
+const char* IN_VAR_LIST_TYPE = R"(py::handle)";
+
+const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
+const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
+
+const char* CAST_VAR_TEMPLATE = R"(
+  auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)";
+
+const char* CAST_VAR_LIST_TEMPLATE = R"(
+  auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)";
+
+
 const char* ARG_TEMPLATE = R"(const %s& %s)";

 const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
@@ -133,6 +144,7 @@ const char* OP_FUNCTION_TEMPLATE =
     R"(
 %s %s(%s)
 {
+  %s
   framework::AttributeMap attrs;
   ConstructAttrMapFromPyArgs(&attrs, args);
   {
@@ -164,6 +176,10 @@ static inline bool FindPassingOutsMap(const std::string& op_type,
   return op_passing_outs_map[op_type].count(out_name);
 }

+static inline std::string TempName(const std::string& name) {
+  return name + '_';
+}
+
 static std::tuple<std::vector<std::string>, std::vector<std::string>>
 GenerateOpFunctions(const std::string& module_name) {
   auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
@@ -187,16 +203,24 @@ GenerateOpFunctions(const std::string& module_name) {
     std::string ins_initializer = "{";
     std::string ins_initializer_with_null = "";
     std::string py_arg = "";
+    int arg_idx = 0;
+    std::string ins_cast_str = "";
     for (auto& input : op_proto->inputs()) {
       auto& in_name = input.name();
       // skip those dispensable inputs, like ResidualData in conv2d
       if (input.dispensable() && !FindInsMap(op_type, in_name)) {
         continue;
       }
-      const auto in_type = input.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
-      auto input_arg = paddle::string::Sprintf(ARG_TEMPLATE, in_type, in_name);
+      const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
+      auto input_arg =
+          paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
       input_args += input_arg;
       input_args += ",";
+      const auto in_cast_type =
+          input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
+      ins_cast_str +=
+          paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
+                                  arg_idx++, TempName(in_name));

       if (input.dispensable()) {
         const auto in_template = input.duplicable()
@@ -235,7 +259,8 @@ GenerateOpFunctions(const std::string& module_name) {
       if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
         continue;
       }
-      const auto out_type = output.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
+      const auto out_type =
+          output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE;
       const auto return_template =
           output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE;
       if (FindPassingOutsMap(op_type, out_name)) {
@@ -309,7 +334,7 @@ GenerateOpFunctions(const std::string& module_name) {

     // generate op funtcion body
     auto op_function_str = paddle::string::Sprintf(
         OP_FUNCTION_TEMPLATE, return_type, func_name, function_args,
-        outs_initializer, ins_initializer,
+        ins_cast_str, outs_initializer, ins_initializer,
         ins_initializer_with_null + outs_initializer_with_null, op_type,
         return_str);
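
Illustration (assumed example, not part of the diff): the standalone sketch below shows the single line of C++ that the generator would stitch together from CAST_VAR_TEMPLATE for a hypothetical non-duplicable input "X" of an op named "my_op". The op and argument names are made up, and std::printf stands in for paddle::string::Sprintf only so the sketch compiles on its own. The format slots are, in order, the result variable name (in_name), the op type, the argument name (in_name again), the positional index (arg_idx), and the underscored py::handle parameter produced by TempName().

    // Sketch of the code line emitted per input by the generator (assumptions:
    // op "my_op", input "X"; std::printf replaces paddle::string::Sprintf).
    #include <cstdio>

    int main() {
      // Same shape as CAST_VAR_TEMPLATE, written with printf-style escapes.
      const char* cast_var_template =
          "\n  auto %s = CastPyHandleToVarBase(\"%s\", \"%s\", %d, %s);";
      // Arguments mirror Sprintf(in_cast_type, in_name, op_type, in_name,
      //                          arg_idx++, TempName(in_name)).
      std::printf(cast_var_template, "X", "my_op", "X", 0, "X_");
      std::printf("\n");
      return 0;
    }

The emitted line lands in the new %s slot added to OP_FUNCTION_TEMPLATE, so each generated op function now receives py::handle arguments (IN_VAR_TYPE), converts them to VarBase objects up front, and reports a mistyped Python argument with the op name, argument name, and position instead of a generic pybind11 cast error.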