From 71b5f1d2b246702ee2e50184ad23bd5b45a21c35 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 8 Apr 2020 13:55:50 +0800 Subject: [PATCH] OP (recurrent) error message enhancement (#23481) * OP (recurrent) error message enhancement --- .../controlflow/recurrent_op_helper.cc | 42 ++++++++++++++----- paddle/fluid/operators/recurrent_op.h | 22 +++++++--- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.cc b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc index d2bb68272df..a515a8b907e 100644 --- a/paddle/fluid/operators/controlflow/recurrent_op_helper.cc +++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc @@ -72,8 +72,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program, OpVariantSet &ops = op_and_grad_op->first; OpVariantSet &grad_ops = op_and_grad_op->second; - PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(), - "There are extra grad ops in the graph or program"); + PADDLE_ENFORCE_GE( + ops.size(), grad_ops.size(), + platform::errors::InvalidArgument( + "There are more grad ops than forward ops in the graph or program, " + "the number of ops is %d and the number of grad_ops is %d.", + ops.size(), grad_ops.size())); for (size_t i = 1; i < program.Size(); ++i) { auto &block = program.Block(i); @@ -87,8 +91,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program, } } - PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(), - "There are extra grad ops in the graph or program"); + PADDLE_ENFORCE_GE( + ops.size(), grad_ops.size(), + platform::errors::InvalidArgument( + "There are more grad ops than forward ops in the graph or program, " + "the number of ops is %d and the number of grad_ops is %d.", + ops.size(), grad_ops.size())); } // Returns GradVarName of input var names @@ -169,7 +177,11 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr( PADDLE_ENFORCE_EQ( fwd_input.size(), in_grads.size(), - "Backward input gradient number does not match forward input number."); + platform::errors::PreconditionNotMet( + "Backward input gradient number does not match forward " + "input number. The number of forward input number is %d and the " + "number of backward input gradient number is %d.", + fwd_input.size(), in_grads.size())); for (size_t i = 0; i < in_grads.size(); ++i) { if (in_grads[i] == framework::kEmptyVarName) { continue; @@ -181,9 +193,13 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr( auto &fwd_param = fwd_op.Inputs().at(RecurrentBase::kParameters); auto ¶m_grads = bwd_op.Outputs().at(framework::GradVarName(RecurrentBase::kParameters)); - PADDLE_ENFORCE_EQ(fwd_param.size(), param_grads.size(), - "Backward parameter gradient number does not match forward " - "parameter number."); + PADDLE_ENFORCE_EQ( + fwd_param.size(), param_grads.size(), + platform::errors::PreconditionNotMet( + "Backward parameter gradient number does not match " + "forward parameter number. The number of forward parameter number is " + "%d and the number of backward parameter gradient is %d.", + fwd_param.size(), param_grads.size())); for (size_t i = 0; i < fwd_param.size(); ++i) { if (param_grads[i] == framework::kEmptyVarName) { continue; @@ -241,12 +257,16 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( const OpVariant *matched_fwd_op = nullptr; for (auto &fwd_op : recurrent_ops) { if (IsMatchedRecurrentOpAndRecurrentGradOp(fwd_op, bwd_op)) { - PADDLE_ENFORCE(matched_fwd_op == nullptr, - "Found multiple matched recurrent op"); + PADDLE_ENFORCE_EQ(matched_fwd_op, nullptr, + platform::errors::PreconditionNotMet( + "Found multiple recurrent forward op matches " + "recurrent grad op.")); matched_fwd_op = &fwd_op; } } - PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, "Cannot find matched forward op"); + PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, + platform::errors::PreconditionNotMet( + "Cannot find matched forward op.")); SetRecurrentOpAndRecurrentGradOpSkipVarAttr(*matched_fwd_op, bwd_op); recurrent_ops.erase(*matched_fwd_op); } diff --git a/paddle/fluid/operators/recurrent_op.h b/paddle/fluid/operators/recurrent_op.h index a4b21448a60..1ca66527e1b 100644 --- a/paddle/fluid/operators/recurrent_op.h +++ b/paddle/fluid/operators/recurrent_op.h @@ -118,7 +118,10 @@ class RecurrentBase : public framework::OperatorBase { const std::vector &dst_vars, Callback callback, bool is_backward = false) { - PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size()); + PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size(), + platform::errors::InvalidArgument( + "Sizes of source vars and destination vars are not " + "equal in LinkTensor.")); for (size_t i = 0; i < dst_vars.size(); ++i) { VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i]; AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback, @@ -136,7 +139,10 @@ class RecurrentBase : public framework::OperatorBase { const std::vector &dst_vars, Callback callback, bool is_backward = false) { - PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size()); + PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size(), + platform::errors::InvalidArgument( + "Sizes of source vars and destination vars are not " + "equal in LinkTensor.")); for (size_t i = 0; i < dst_vars.size(); ++i) { VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i]; AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback, @@ -159,7 +165,9 @@ class RecurrentBase : public framework::OperatorBase { if (is_backward && src_var == nullptr) { return; } - PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name); + PADDLE_ENFORCE_NOT_NULL( + src_var, platform::errors::NotFound("Source variable %s is not found.", + src_var_name)); auto &src_tensor = src_var->Get(); auto *dst_var = dst_scope->Var(dst_var_name); @@ -178,9 +186,13 @@ class RecurrentBase : public framework::OperatorBase { return; } auto *src_var = src_scope.FindVar(src_var_name); - PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name); + PADDLE_ENFORCE_NOT_NULL( + src_var, platform::errors::NotFound("Source variable %s is not found.", + src_var_name)); auto &src_tensor = src_var->Get(); - PADDLE_ENFORCE_NOT_NULL(dst_var, "%s is not found.", dst_var_name); + PADDLE_ENFORCE_NOT_NULL( + dst_var, platform::errors::NotFound( + "Destination variable %s is not found.", src_var_name)); auto *dst_tensor = dst_var->GetMutable(); callback(src_tensor, dst_tensor); } -- GitLab