未验证 提交 286c2e0e 编写于 作者: W Wilber 提交者: GitHub

error message enhancement for py_func op. (#23565)

error message enhancement for py_func op. 
上级 94a3789f
...@@ -66,7 +66,7 @@ void FusionRepeatedFCReluOp::InferShape( ...@@ -66,7 +66,7 @@ void FusionRepeatedFCReluOp::InferShape(
for (size_t i = 1; i < sz; ++i) { for (size_t i = 1; i < sz; ++i) {
PADDLE_ENFORCE_EQ(w_dims[i].size(), 2, PADDLE_ENFORCE_EQ(w_dims[i].size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Every weight shape size should be 2., but received " "Every weight shape size should be 2, but received "
"w_dims[%d].size() = %d.", "w_dims[%d].size() = %d.",
i, w_dims[i].size())); i, w_dims[i].size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -42,7 +42,11 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) { ...@@ -42,7 +42,11 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
// Returning py::object would cause reference count increasing // Returning py::object would cause reference count increasing
// but without GIL, reference count in Python may not be safe // but without GIL, reference count in Python may not be safe
static py::object *GetPythonCallableObject(size_t i) { static py::object *GetPythonCallableObject(size_t i) {
PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id"); PADDLE_ENFORCE_LT(
i, g_py_callables.size(),
platform::errors::InvalidArgument(
"Invalid python callable id %d, which should be less than %d.", i,
g_py_callables.size()));
return &g_py_callables[i]; return &g_py_callables[i];
} }
...@@ -71,10 +75,27 @@ static void CallPythonFunc(py::object *callable, ...@@ -71,10 +75,27 @@ static void CallPythonFunc(py::object *callable,
// Python function has no return values or returns None // Python function has no return values or returns None
// In this case, ret_num = 1 && ret[0] == None && out_num should be 0 // In this case, ret_num = 1 && ret[0] == None && out_num should be 0
// Otherwise, ret_num must be equal to out_num // Otherwise, ret_num must be equal to out_num
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(ret_num == 1, true,
ret_num == 1 && out_num == 0 && platform::errors::InvalidArgument(
py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr, "Python function has no return values or returns "
"Output number not match. Expected %d, actual %d", out_num, ret_num); "None. In this case, ret_num = 1 && ret[0] == None "
"&& out_num should be 0. But ret_num is %d",
ret_num));
PADDLE_ENFORCE_EQ(
out_num == 0, true,
platform::errors::InvalidArgument(
"Python function has no return values or returns None. In "
"this case, ret_num = 1 && ret[0] == None && out_num should "
"be 0. But out_num is %d",
out_num));
PADDLE_ENFORCE_EQ(
py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr, true,
platform::errors::InvalidArgument(
"Python function has no return values or returns None. In "
"this case, ret_num = 1 && ret[0] == None && out_num should "
"be 0. But ret[0] is not None"));
} }
for (size_t i = 0; i < out_num; ++i) { for (size_t i = 0; i < out_num; ++i) {
...@@ -85,7 +106,8 @@ static void CallPythonFunc(py::object *callable, ...@@ -85,7 +106,8 @@ static void CallPythonFunc(py::object *callable,
try { try {
auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]); auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
PADDLE_ENFORCE_NOT_NULL(py_out_tensor, PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
"Output tensor %d should not be nullptr", i); platform::errors::InvalidArgument(
"Output tensor %d should not be nullptr", i));
out->set_lod(py_out_tensor->lod()); out->set_lod(py_out_tensor->lod());
out->ShareDataWith(*py_out_tensor); out->ShareDataWith(*py_out_tensor);
} catch (py::cast_error &) { } catch (py::cast_error &) {
...@@ -105,10 +127,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { ...@@ -105,10 +127,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* X or Out can be empty, so that py_func can be more flexible * X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output * to support Python functions with no input or no output
*/ */
PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist"); PADDLE_ENFORCE_EQ(
has_in || has_out, true,
PADDLE_ENFORCE_GE(boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)), platform::errors::InvalidArgument("Input(X) or Output(Out) must exist, "
0, "Function id cannot be less than 0"); "but has_in is %d, has_out is %d.",
has_in, has_out));
PADDLE_ENFORCE_GE(
boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)), 0,
platform::errors::InvalidArgument(
"Function id cannot be less than 0, but received value is %d.",
boost::get<int>(ctx->GetAttr(kForwardPythonCallableId))));
if (!has_out) return; if (!has_out) return;
...@@ -128,10 +157,12 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { ...@@ -128,10 +157,12 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t len = out_var_name.size() - kGradVarSuffix.size(); size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) { if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len); auto fwd_var_name = out_var_name.substr(0, len);
PADDLE_ENFORCE(ctx->HasVar(out_var_name), PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true,
"Backward variable %s not found", out_var_name); platform::errors::InvalidArgument(
PADDLE_ENFORCE(ctx->HasVar(fwd_var_name), "Backward variable %s not found", out_var_name));
"Backward variable %s not found", fwd_var_name); PADDLE_ENFORCE_EQ(ctx->HasVar(fwd_var_name), true,
platform::errors::InvalidArgument(
"Backward variable %s not found", fwd_var_name));
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input(" VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")"; << fwd_var_name << ")";
...@@ -147,8 +178,9 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { ...@@ -147,8 +178,9 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
class PyFuncOpShapeInference : public framework::InferShapeBase { class PyFuncOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(), PADDLE_ENFORCE_EQ(!ctx->IsRuntime(), true,
"Infer shape cannot be called in runtime."); platform::errors::InvalidArgument(
"Infer shape cannot be called in runtime."));
} }
}; };
......
...@@ -12820,6 +12820,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): ...@@ -12820,6 +12820,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
# [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)] # [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)]
""" """
helper = LayerHelper('py_func', **locals()) helper = LayerHelper('py_func', **locals())
check_type(x, 'X', (list, tuple, Variable, type(None)), 'py_func')
if x is None: if x is None:
x = [] x = []
elif isinstance(x, Variable): elif isinstance(x, Variable):
...@@ -12828,7 +12829,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): ...@@ -12828,7 +12829,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
x = list(x) x = list(x)
elif not isinstance(x, (list, tuple, Variable)): elif not isinstance(x, (list, tuple, Variable)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)') raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
check_type(out, 'Out', (list, tuple, Variable, type(None)), 'py_func')
if out is None: if out is None:
out_list = [] out_list = []
elif isinstance(out, Variable): elif isinstance(out, Variable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册