未验证 提交 049ac56c 编写于 作者: L Leo Chen 提交者: GitHub

Print user-friendly error message in core.ops [part 2] (#26377)

上级 fd0051b4
......@@ -90,15 +90,36 @@ CastPyHandleToVarBaseList(const std::string& op_type,
return result;
} // namespace pybind
static inline void ConstructAttrMapFromPyArgs(framework::AttributeMap* attrs,
static inline void ConstructAttrMapFromPyArgs(const std::string& op_type,
int start_idx,
framework::AttributeMap* attrs,
const py::args& args) {
PADDLE_ENFORCE_EQ(
args.size() % 2, 0,
platform::errors::InvalidArgument(
"The number of arguments for arributes should be even."));
for (size_t i = 0; i < args.size(); i += 2) {
auto name = args[i].cast<std::string>();
auto value = args[i + 1].cast<framework::Attribute>();
std::string name;
framework::Attribute value;
try {
name = args[i].cast<std::string>();
} catch (std::exception& e) {
PyObject* py_obj = args[i].ptr(); // get underlying PyObject
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be str, but got "
"%s",
op_type, start_idx + i, Py_TYPE(py_obj)->tp_name));
}
try {
value = args[i + 1].cast<framework::Attribute>();
} catch (std::exception& e) {
PyObject* py_obj = args[i + 1].ptr(); // get underlying PyObject
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"Attribute type (one of str, bool, int, int64, float, or list of "
"them), but got %s",
op_type, start_idx + i + 1, Py_TYPE(py_obj)->tp_name));
}
(*attrs)[name] = value;
}
}
......
......@@ -146,7 +146,7 @@ R"(
{
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs(&attrs, args);
ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
{
py::gil_scoped_release release;
auto tracer = imperative::GetCurrentTracer();
......@@ -204,6 +204,7 @@ GenerateOpFunctions(const std::string& module_name) {
std::string ins_initializer_with_null = "";
std::string py_arg = "";
int arg_idx = 0;
int input_args_num = 0;
std::string ins_cast_str = "";
for (auto& input : op_proto->inputs()) {
auto& in_name = input.name();
......@@ -216,6 +217,7 @@ GenerateOpFunctions(const std::string& module_name) {
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
input_args += input_arg;
input_args += ",";
input_args_num++;
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
ins_cast_str +=
......@@ -269,6 +271,7 @@ GenerateOpFunctions(const std::string& module_name) {
}
input_args += out_type;
input_args += out_name;
input_args_num++;
if (output.dispensable()) {
const auto out_template =
......@@ -295,6 +298,7 @@ GenerateOpFunctions(const std::string& module_name) {
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str;
input_args_num++;
outs_initializer += paddle::string::Sprintf(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
} else {
......@@ -334,9 +338,9 @@ GenerateOpFunctions(const std::string& module_name) {
// generate op funtcion body
auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args,
ins_cast_str, outs_initializer, ins_initializer,
ins_initializer_with_null + outs_initializer_with_null, op_type,
return_str);
ins_cast_str, op_type, input_args_num, outs_initializer,
ins_initializer, ins_initializer_with_null + outs_initializer_with_null,
op_type, return_str);
// generate pybind item
auto bind_function_str = paddle::string::Sprintf(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册