未验证 提交 286c2e0e 编写于 作者: W Wilber 提交者: GitHub

error message enhancement for py_func op. (#23565)

error message enhancement for py_func op. 
上级 94a3789f
......@@ -66,7 +66,7 @@ void FusionRepeatedFCReluOp::InferShape(
for (size_t i = 1; i < sz; ++i) {
PADDLE_ENFORCE_EQ(w_dims[i].size(), 2,
platform::errors::InvalidArgument(
"Every weight shape size should be 2., but received "
"Every weight shape size should be 2, but received "
"w_dims[%d].size() = %d.",
i, w_dims[i].size()));
PADDLE_ENFORCE_EQ(
......
......@@ -42,7 +42,11 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
// Returning py::object would cause reference count increasing
// but without GIL, reference count in Python may not be safe
static py::object *GetPythonCallableObject(size_t i) {
PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id");
PADDLE_ENFORCE_LT(
i, g_py_callables.size(),
platform::errors::InvalidArgument(
"Invalid python callable id %d, which should be less than %d.", i,
g_py_callables.size()));
return &g_py_callables[i];
}
......@@ -71,10 +75,27 @@ static void CallPythonFunc(py::object *callable,
// Python function has no return values or returns None
// In this case, ret_num = 1 && ret[0] == None && out_num should be 0
// Otherwise, ret_num must be equal to out_num
PADDLE_ENFORCE(
ret_num == 1 && out_num == 0 &&
py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr,
"Output number not match. Expected %d, actual %d", out_num, ret_num);
PADDLE_ENFORCE_EQ(ret_num == 1, true,
platform::errors::InvalidArgument(
"Python function has no return values or returns "
"None. In this case, ret_num = 1 && ret[0] == None "
"&& out_num should be 0. But ret_num is %d",
ret_num));
PADDLE_ENFORCE_EQ(
out_num == 0, true,
platform::errors::InvalidArgument(
"Python function has no return values or returns None. In "
"this case, ret_num = 1 && ret[0] == None && out_num should "
"be 0. But out_num is %d",
out_num));
PADDLE_ENFORCE_EQ(
py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr, true,
platform::errors::InvalidArgument(
"Python function has no return values or returns None. In "
"this case, ret_num = 1 && ret[0] == None && out_num should "
"be 0. But ret[0] is not None"));
}
for (size_t i = 0; i < out_num; ++i) {
......@@ -85,7 +106,8 @@ static void CallPythonFunc(py::object *callable,
try {
auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
"Output tensor %d should not be nullptr", i);
platform::errors::InvalidArgument(
"Output tensor %d should not be nullptr", i));
out->set_lod(py_out_tensor->lod());
out->ShareDataWith(*py_out_tensor);
} catch (py::cast_error &) {
......@@ -105,10 +127,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output
*/
PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
PADDLE_ENFORCE_GE(boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)),
0, "Function id cannot be less than 0");
PADDLE_ENFORCE_EQ(
has_in || has_out, true,
platform::errors::InvalidArgument("Input(X) or Output(Out) must exist, "
"but has_in is %d, has_out is %d.",
has_in, has_out));
PADDLE_ENFORCE_GE(
boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)), 0,
platform::errors::InvalidArgument(
"Function id cannot be less than 0, but received value is %d.",
boost::get<int>(ctx->GetAttr(kForwardPythonCallableId))));
if (!has_out) return;
......@@ -128,10 +157,12 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len);
PADDLE_ENFORCE(ctx->HasVar(out_var_name),
"Backward variable %s not found", out_var_name);
PADDLE_ENFORCE(ctx->HasVar(fwd_var_name),
"Backward variable %s not found", fwd_var_name);
PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true,
platform::errors::InvalidArgument(
"Backward variable %s not found", out_var_name));
PADDLE_ENFORCE_EQ(ctx->HasVar(fwd_var_name), true,
platform::errors::InvalidArgument(
"Backward variable %s not found", fwd_var_name));
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")";
......@@ -147,8 +178,9 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
class PyFuncOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"Infer shape cannot be called in runtime.");
PADDLE_ENFORCE_EQ(!ctx->IsRuntime(), true,
platform::errors::InvalidArgument(
"Infer shape cannot be called in runtime."));
}
};
......
......@@ -12820,6 +12820,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
# [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)]
"""
helper = LayerHelper('py_func', **locals())
check_type(x, 'X', (list, tuple, Variable, type(None)), 'py_func')
if x is None:
x = []
elif isinstance(x, Variable):
......@@ -12828,7 +12829,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
x = list(x)
elif not isinstance(x, (list, tuple, Variable)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
check_type(out, 'Out', (list, tuple, Variable, type(None)), 'py_func')
if out is None:
out_list = []
elif isinstance(out, Variable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册