diff --git a/paddle/fluid/pybind/op_function.h b/paddle/fluid/pybind/op_function.h index b33555759f8ff14b59aec87b1ec8b6a0447025b1..70b321f658cd2cf1bd43cd6440bf83e1f4dab140 100644 --- a/paddle/fluid/pybind/op_function.h +++ b/paddle/fluid/pybind/op_function.h @@ -90,15 +90,36 @@ CastPyHandleToVarBaseList(const std::string& op_type, return result; } // namespace pybind -static inline void ConstructAttrMapFromPyArgs(framework::AttributeMap* attrs, +static inline void ConstructAttrMapFromPyArgs(const std::string& op_type, + int start_idx, + framework::AttributeMap* attrs, const py::args& args) { PADDLE_ENFORCE_EQ( args.size() % 2, 0, platform::errors::InvalidArgument( "The number of arguments for arributes should be even.")); for (size_t i = 0; i < args.size(); i += 2) { - auto name = args[i].cast(); - auto value = args[i + 1].cast(); + std::string name; + framework::Attribute value; + try { + name = args[i].cast(); + } catch (std::exception& e) { + PyObject* py_obj = args[i].ptr(); // get underlying PyObject + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be str, but got " + "%s", + op_type, start_idx + i, Py_TYPE(py_obj)->tp_name)); + } + try { + value = args[i + 1].cast(); + } catch (std::exception& e) { + PyObject* py_obj = args[i + 1].ptr(); // get underlying PyObject + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be " + "Attribute type (one of str, bool, int, int64, float, or list of " + "them), but got %s", + op_type, start_idx + i + 1, Py_TYPE(py_obj)->tp_name)); + } (*attrs)[name] = value; } } diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 93ba9feedf95b0e342cda8a40f83ba9ca471858f..89770ccc8cec1b0c5fcfc4c1033a691cc674566e 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -146,7 +146,7 @@ R"( { %s framework::AttributeMap attrs; - ConstructAttrMapFromPyArgs(&attrs, args); + ConstructAttrMapFromPyArgs("%s", %d, &attrs, args); { py::gil_scoped_release release; auto tracer = imperative::GetCurrentTracer(); @@ -204,6 +204,7 @@ GenerateOpFunctions(const std::string& module_name) { std::string ins_initializer_with_null = ""; std::string py_arg = ""; int arg_idx = 0; + int input_args_num = 0; std::string ins_cast_str = ""; for (auto& input : op_proto->inputs()) { auto& in_name = input.name(); @@ -216,6 +217,7 @@ GenerateOpFunctions(const std::string& module_name) { paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); input_args += input_arg; input_args += ","; + input_args_num++; const auto in_cast_type = input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; ins_cast_str += @@ -269,6 +271,7 @@ GenerateOpFunctions(const std::string& module_name) { } input_args += out_type; input_args += out_name; + input_args_num++; if (output.dispensable()) { const auto out_template = @@ -295,6 +298,7 @@ GenerateOpFunctions(const std::string& module_name) { auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); input_args += ARG_OUT_NUM_TYPE; input_args += out_num_str; + input_args_num++; outs_initializer += paddle::string::Sprintf( OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); } else { @@ -334,9 +338,9 @@ GenerateOpFunctions(const std::string& module_name) { // generate op funtcion body auto op_function_str = paddle::string::Sprintf( OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, - ins_cast_str, outs_initializer, ins_initializer, - ins_initializer_with_null + outs_initializer_with_null, op_type, - return_str); + ins_cast_str, op_type, input_args_num, outs_initializer, + ins_initializer, ins_initializer_with_null + outs_initializer_with_null, + op_type, return_str); // generate pybind item auto bind_function_str = paddle::string::Sprintf(