未验证 提交 672578a7 编写于 作者: L Leo Chen 提交者: GitHub

Print user-friendly error message in core.ops (#26261)

* print user-friendly error message

* adjust error sumary
上级 d7bdc9fe
......@@ -266,7 +266,7 @@ inline std::string GetErrorSumaryString(StrType&& what, const char* file,
std::ostringstream sout;
sout << "\n----------------------\nError Message "
"Summary:\n----------------------\n";
sout << string::Sprintf("%s at (%s:%d)", std::forward<StrType>(what), file,
sout << string::Sprintf("%s (at %s:%d)", std::forward<StrType>(what), file,
line)
<< std::endl;
return sout.str();
......
......@@ -18,9 +18,11 @@
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/variable.h"
......@@ -31,6 +33,63 @@
namespace py = pybind11;
namespace paddle {
namespace pybind {
static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
const std::string& op_type, const std::string& arg_name, int arg_idx,
const py::handle& handle) {
PyObject* py_obj = handle.ptr(); // get underlying PyObject
if (!py_obj || py_obj == Py_None) {
return nullptr;
}
try {
return py::cast<std::shared_ptr<imperative::VarBase>>(py::handle(py_obj));
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
}
}
static inline std::vector<std::shared_ptr<imperative::VarBase>>
CastPyHandleToVarBaseList(const std::string& op_type,
const std::string& arg_name, int arg_idx,
const py::handle& handle) {
PyObject* py_obj = handle.ptr(); // get underlying PyObject
if (!py_obj || py_obj == Py_None) {
return {};
}
std::vector<std::shared_ptr<imperative::VarBase>> result;
if (PyList_Check(py_obj) || PyTuple_Check(py_obj)) {
auto size = PyTuple_Check(py_obj) ? PyTuple_GET_SIZE(py_obj)
: PyList_GET_SIZE(py_obj);
for (auto i = 0; i < size; ++i) {
PyObject* item = PyTuple_Check(py_obj) ? PyTuple_GET_ITEM(py_obj, i)
: PyList_GET_ITEM(py_obj, i);
if (!item || item == Py_None) {
result.emplace_back(nullptr);
continue;
}
try {
result.emplace_back(
py::cast<std::shared_ptr<imperative::VarBase>>(py::handle(item)));
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of "
"Tensors, but "
"got %s in list (item %d)",
op_type, arg_name, arg_idx, Py_TYPE(item)->tp_name, i));
}
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"%s",
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
}
return result;
} // namespace pybind
static inline void ConstructAttrMapFromPyArgs(framework::AttributeMap* attrs,
const py::args& args) {
PADDLE_ENFORCE_EQ(
......
......@@ -116,8 +116,19 @@ const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
const char* ARG_OUT_NUM = R"(%sNum)";
const char* ARG_OUT_NUM_TYPE = R"(size_t )";
const char* VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* IN_VAR_TYPE = R"(py::handle)";
const char* IN_VAR_LIST_TYPE = R"(py::handle)";
const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* CAST_VAR_TEMPLATE = R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)";
const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)";
const char* ARG_TEMPLATE = R"(const %s& %s)";
const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
......@@ -133,6 +144,7 @@ const char* OP_FUNCTION_TEMPLATE =
R"(
%s %s(%s)
{
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs(&attrs, args);
{
......@@ -164,6 +176,10 @@ static inline bool FindPassingOutsMap(const std::string& op_type,
return op_passing_outs_map[op_type].count(out_name);
}
static inline std::string TempName(const std::string& name) {
return name + '_';
}
static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
......@@ -187,16 +203,24 @@ GenerateOpFunctions(const std::string& module_name) {
std::string ins_initializer = "{";
std::string ins_initializer_with_null = "";
std::string py_arg = "";
int arg_idx = 0;
std::string ins_cast_str = "";
for (auto& input : op_proto->inputs()) {
auto& in_name = input.name();
// skip those dispensable inputs, like ResidualData in conv2d
if (input.dispensable() && !FindInsMap(op_type, in_name)) {
continue;
}
const auto in_type = input.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
auto input_arg = paddle::string::Sprintf(ARG_TEMPLATE, in_type, in_name);
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg =
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
input_args += input_arg;
input_args += ",";
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
arg_idx++, TempName(in_name));
if (input.dispensable()) {
const auto in_template = input.duplicable()
......@@ -235,7 +259,8 @@ GenerateOpFunctions(const std::string& module_name) {
if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
continue;
}
const auto out_type = output.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
const auto out_type =
output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE;
const auto return_template =
output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE;
if (FindPassingOutsMap(op_type, out_name)) {
......@@ -309,7 +334,7 @@ GenerateOpFunctions(const std::string& module_name) {
// generate op funtcion body
auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args,
outs_initializer, ins_initializer,
ins_cast_str, outs_initializer, ins_initializer,
ins_initializer_with_null + outs_initializer_with_null, op_type,
return_str);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册