diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc index 83185639d6219ead9a594d036de542395de8eb0c..d65206934bcdaf8113b090e24adf15cb50e6e3b7 100644 --- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc @@ -66,7 +66,7 @@ void FusionRepeatedFCReluOp::InferShape( for (size_t i = 1; i < sz; ++i) { PADDLE_ENFORCE_EQ(w_dims[i].size(), 2, platform::errors::InvalidArgument( - "Every weight shape size should be 2., but received " + "Every weight shape size should be 2, but received " "w_dims[%d].size() = %d.", i, w_dims[i].size())); PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 5300e807472d3bb243dc198c0bfd1bc572538015..b9c32304353e0e715cd79ea9d604cdadf6fde44f 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -42,7 +42,11 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) { // Returning py::object would cause reference count increasing // but without GIL, reference count in Python may not be safe static py::object *GetPythonCallableObject(size_t i) { - PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id"); + PADDLE_ENFORCE_LT( + i, g_py_callables.size(), + platform::errors::InvalidArgument( + "Invalid python callable id %d, which should be less than %d.", i, + g_py_callables.size())); return &g_py_callables[i]; } @@ -71,10 +75,27 @@ static void CallPythonFunc(py::object *callable, // Python function has no return values or returns None // In this case, ret_num = 1 && ret[0] == None && out_num should be 0 // Otherwise, ret_num must be equal to out_num - PADDLE_ENFORCE( - ret_num == 1 && out_num == 0 && - py::cast(ret_tuple[0]) == nullptr, - "Output number not match. Expected %d, actual %d", out_num, ret_num); + PADDLE_ENFORCE_EQ(ret_num == 1, true, + platform::errors::InvalidArgument( + "Python function has no return values or returns " + "None. In this case, ret_num = 1 && ret[0] == None " + "&& out_num should be 0. But ret_num is %d", + ret_num)); + + PADDLE_ENFORCE_EQ( + out_num == 0, true, + platform::errors::InvalidArgument( + "Python function has no return values or returns None. In " + "this case, ret_num = 1 && ret[0] == None && out_num should " + "be 0. But out_num is %d", + out_num)); + + PADDLE_ENFORCE_EQ( + py::cast(ret_tuple[0]) == nullptr, true, + platform::errors::InvalidArgument( + "Python function has no return values or returns None. In " + "this case, ret_num = 1 && ret[0] == None && out_num should " + "be 0. But ret[0] is not None")); } for (size_t i = 0; i < out_num; ++i) { @@ -85,7 +106,8 @@ static void CallPythonFunc(py::object *callable, try { auto *py_out_tensor = py::cast(ret_tuple[i]); PADDLE_ENFORCE_NOT_NULL(py_out_tensor, - "Output tensor %d should not be nullptr", i); + platform::errors::InvalidArgument( + "Output tensor %d should not be nullptr", i)); out->set_lod(py_out_tensor->lod()); out->ShareDataWith(*py_out_tensor); } catch (py::cast_error &) { @@ -105,10 +127,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { * X or Out can be empty, so that py_func can be more flexible * to support Python functions with no input or no output */ - PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist"); - - PADDLE_ENFORCE_GE(boost::get(ctx->GetAttr(kForwardPythonCallableId)), - 0, "Function id cannot be less than 0"); + PADDLE_ENFORCE_EQ( + has_in || has_out, true, + platform::errors::InvalidArgument("Input(X) or Output(Out) must exist, " + "but has_in is %d, has_out is %d.", + has_in, has_out)); + + PADDLE_ENFORCE_GE( + boost::get(ctx->GetAttr(kForwardPythonCallableId)), 0, + platform::errors::InvalidArgument( + "Function id cannot be less than 0, but received value is %d.", + boost::get(ctx->GetAttr(kForwardPythonCallableId)))); if (!has_out) return; @@ -128,10 +157,12 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { size_t len = out_var_name.size() - kGradVarSuffix.size(); if (out_var_name.substr(len) == kGradVarSuffix) { auto fwd_var_name = out_var_name.substr(0, len); - PADDLE_ENFORCE(ctx->HasVar(out_var_name), - "Backward variable %s not found", out_var_name); - PADDLE_ENFORCE(ctx->HasVar(fwd_var_name), - "Backward variable %s not found", fwd_var_name); + PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true, + platform::errors::InvalidArgument( + "Backward variable %s not found", out_var_name)); + PADDLE_ENFORCE_EQ(ctx->HasVar(fwd_var_name), true, + platform::errors::InvalidArgument( + "Backward variable %s not found", fwd_var_name)); VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input(" << fwd_var_name << ")"; @@ -147,8 +178,9 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { class PyFuncOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(!ctx->IsRuntime(), - "Infer shape cannot be called in runtime."); + PADDLE_ENFORCE_EQ(!ctx->IsRuntime(), true, + platform::errors::InvalidArgument( + "Infer shape cannot be called in runtime.")); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 5d41729f631b4f267aff683b7bd7882c4784f7cb..9b2d1d233b2dc50eaaffc40878a988cd68f1926c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12820,6 +12820,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): # [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)] """ helper = LayerHelper('py_func', **locals()) + check_type(x, 'X', (list, tuple, Variable, type(None)), 'py_func') if x is None: x = [] elif isinstance(x, Variable): @@ -12828,7 +12829,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): x = list(x) elif not isinstance(x, (list, tuple, Variable)): raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)') - + check_type(out, 'Out', (list, tuple, Variable, type(None)), 'py_func') if out is None: out_list = [] elif isinstance(out, Variable):