未验证 提交 049ac56c 编写于 作者: L Leo Chen 提交者: GitHub

Print user-friendly error message in core.ops [part 2] (#26377)

上级 fd0051b4
...@@ -90,15 +90,36 @@ CastPyHandleToVarBaseList(const std::string& op_type, ...@@ -90,15 +90,36 @@ CastPyHandleToVarBaseList(const std::string& op_type,
return result; return result;
} // namespace pybind } // namespace pybind
static inline void ConstructAttrMapFromPyArgs(framework::AttributeMap* attrs, static inline void ConstructAttrMapFromPyArgs(const std::string& op_type,
int start_idx,
framework::AttributeMap* attrs,
const py::args& args) { const py::args& args) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
args.size() % 2, 0, args.size() % 2, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The number of arguments for arributes should be even.")); "The number of arguments for arributes should be even."));
for (size_t i = 0; i < args.size(); i += 2) { for (size_t i = 0; i < args.size(); i += 2) {
auto name = args[i].cast<std::string>(); std::string name;
auto value = args[i + 1].cast<framework::Attribute>(); framework::Attribute value;
try {
name = args[i].cast<std::string>();
} catch (std::exception& e) {
PyObject* py_obj = args[i].ptr(); // get underlying PyObject
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be str, but got "
"%s",
op_type, start_idx + i, Py_TYPE(py_obj)->tp_name));
}
try {
value = args[i + 1].cast<framework::Attribute>();
} catch (std::exception& e) {
PyObject* py_obj = args[i + 1].ptr(); // get underlying PyObject
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"Attribute type (one of str, bool, int, int64, float, or list of "
"them), but got %s",
op_type, start_idx + i + 1, Py_TYPE(py_obj)->tp_name));
}
(*attrs)[name] = value; (*attrs)[name] = value;
} }
} }
......
...@@ -146,7 +146,7 @@ R"( ...@@ -146,7 +146,7 @@ R"(
{ {
%s %s
framework::AttributeMap attrs; framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs(&attrs, args); ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
auto tracer = imperative::GetCurrentTracer(); auto tracer = imperative::GetCurrentTracer();
...@@ -204,6 +204,7 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -204,6 +204,7 @@ GenerateOpFunctions(const std::string& module_name) {
std::string ins_initializer_with_null = ""; std::string ins_initializer_with_null = "";
std::string py_arg = ""; std::string py_arg = "";
int arg_idx = 0; int arg_idx = 0;
int input_args_num = 0;
std::string ins_cast_str = ""; std::string ins_cast_str = "";
for (auto& input : op_proto->inputs()) { for (auto& input : op_proto->inputs()) {
auto& in_name = input.name(); auto& in_name = input.name();
...@@ -216,6 +217,7 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -216,6 +217,7 @@ GenerateOpFunctions(const std::string& module_name) {
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
input_args += input_arg; input_args += input_arg;
input_args += ","; input_args += ",";
input_args_num++;
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
ins_cast_str += ins_cast_str +=
...@@ -269,6 +271,7 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -269,6 +271,7 @@ GenerateOpFunctions(const std::string& module_name) {
} }
input_args += out_type; input_args += out_type;
input_args += out_name; input_args += out_name;
input_args_num++;
if (output.dispensable()) { if (output.dispensable()) {
const auto out_template = const auto out_template =
...@@ -295,6 +298,7 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -295,6 +298,7 @@ GenerateOpFunctions(const std::string& module_name) {
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
input_args += ARG_OUT_NUM_TYPE; input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str; input_args += out_num_str;
input_args_num++;
outs_initializer += paddle::string::Sprintf( outs_initializer += paddle::string::Sprintf(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
} else { } else {
...@@ -334,9 +338,9 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -334,9 +338,9 @@ GenerateOpFunctions(const std::string& module_name) {
// generate op funtcion body // generate op funtcion body
auto op_function_str = paddle::string::Sprintf( auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, OP_FUNCTION_TEMPLATE, return_type, func_name, function_args,
ins_cast_str, outs_initializer, ins_initializer, ins_cast_str, op_type, input_args_num, outs_initializer,
ins_initializer_with_null + outs_initializer_with_null, op_type, ins_initializer, ins_initializer_with_null + outs_initializer_with_null,
return_str); op_type, return_str);
// generate pybind item // generate pybind item
auto bind_function_str = paddle::string::Sprintf( auto bind_function_str = paddle::string::Sprintf(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册