Unverified commit 71b5f1d2, authored by Huihuang Zheng, committed by GitHub

OP (recurrent) error message enhancement (#23481)

* OP (recurrent) error message enhancement
Parent commit: 8674a82c
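Every change in this commit follows one pattern: a PADDLE_ENFORCE_* check whose bare string message is replaced by a typed error from platform::errors (InvalidArgument, PreconditionNotMet, NotFound) whose message also reports the values that failed the check. The snippet below is a minimal sketch of that pattern and is not part of the commit; the helper CheckLinkedVarSizes and its arguments are hypothetical, and the include path is an assumption about the in-tree header layout.

#include <string>
#include <vector>

// Assumed in-tree header providing the PADDLE_ENFORCE_* macros and
// the platform::errors error types.
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

// Hypothetical helper (illustration only, not part of this commit): the
// check macro receives a typed error object whose message is formatted
// printf-style with the values that actually failed the comparison,
// instead of a bare string literal.
static void CheckLinkedVarSizes(const std::vector<std::string> &src_vars,
                                const std::vector<std::string> &dst_vars) {
  PADDLE_ENFORCE_EQ(
      src_vars.size(), dst_vars.size(),
      platform::errors::InvalidArgument(
          "Sizes of source vars and destination vars are not equal, the "
          "number of source vars is %d and the number of destination vars "
          "is %d.",
          src_vars.size(), dst_vars.size()));
}

}  // namespace operators
}  // namespace paddle

Compared with the old PADDLE_ENFORCE(cond, "message") form, the typed error classifies the failure, and the formatted arguments make the raised message self-explanatory without rerunning the program.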
@@ -72,8 +72,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
   OpVariantSet &ops = op_and_grad_op->first;
   OpVariantSet &grad_ops = op_and_grad_op->second;
-  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
-                    "There are extra grad ops in the graph or program");
+  PADDLE_ENFORCE_GE(
+      ops.size(), grad_ops.size(),
+      platform::errors::InvalidArgument(
+          "There are more grad ops than forward ops in the graph or program, "
+          "the number of ops is %d and the number of grad_ops is %d.",
+          ops.size(), grad_ops.size()));
   for (size_t i = 1; i < program.Size(); ++i) {
     auto &block = program.Block(i);
@@ -87,8 +91,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
     }
   }
-  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
-                    "There are extra grad ops in the graph or program");
+  PADDLE_ENFORCE_GE(
+      ops.size(), grad_ops.size(),
+      platform::errors::InvalidArgument(
+          "There are more grad ops than forward ops in the graph or program, "
+          "the number of ops is %d and the number of grad_ops is %d.",
+          ops.size(), grad_ops.size()));
 }
 
 // Returns GradVarName of input var names
@@ -169,7 +177,11 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
   PADDLE_ENFORCE_EQ(
       fwd_input.size(), in_grads.size(),
-      "Backward input gradient number does not match forward input number.");
+      platform::errors::PreconditionNotMet(
+          "Backward input gradient number does not match forward "
+          "input number. The number of forward input number is %d and the "
+          "number of backward input gradient number is %d.",
+          fwd_input.size(), in_grads.size()));
   for (size_t i = 0; i < in_grads.size(); ++i) {
     if (in_grads[i] == framework::kEmptyVarName) {
       continue;
@@ -181,9 +193,13 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
   auto &fwd_param = fwd_op.Inputs().at(RecurrentBase::kParameters);
   auto &param_grads =
       bwd_op.Outputs().at(framework::GradVarName(RecurrentBase::kParameters));
-  PADDLE_ENFORCE_EQ(fwd_param.size(), param_grads.size(),
-                    "Backward parameter gradient number does not match forward "
-                    "parameter number.");
+  PADDLE_ENFORCE_EQ(
+      fwd_param.size(), param_grads.size(),
+      platform::errors::PreconditionNotMet(
+          "Backward parameter gradient number does not match "
+          "forward parameter number. The number of forward parameter number is "
+          "%d and the number of backward parameter gradient is %d.",
+          fwd_param.size(), param_grads.size()));
   for (size_t i = 0; i < fwd_param.size(); ++i) {
     if (param_grads[i] == framework::kEmptyVarName) {
       continue;
@@ -241,12 +257,16 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
     const OpVariant *matched_fwd_op = nullptr;
     for (auto &fwd_op : recurrent_ops) {
       if (IsMatchedRecurrentOpAndRecurrentGradOp(fwd_op, bwd_op)) {
-        PADDLE_ENFORCE(matched_fwd_op == nullptr,
-                       "Found multiple matched recurrent op");
+        PADDLE_ENFORCE_EQ(matched_fwd_op, nullptr,
+                          platform::errors::PreconditionNotMet(
+                              "Found multiple recurrent forward op matches "
+                              "recurrent grad op."));
         matched_fwd_op = &fwd_op;
       }
     }
-    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, "Cannot find matched forward op");
+    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
+                            platform::errors::PreconditionNotMet(
+                                "Cannot find matched forward op."));
     SetRecurrentOpAndRecurrentGradOpSkipVarAttr(*matched_fwd_op, bwd_op);
     recurrent_ops.erase(*matched_fwd_op);
   }
......
@@ -118,7 +118,10 @@ class RecurrentBase : public framework::OperatorBase {
                                      const std::vector<std::string> &dst_vars,
                                      Callback callback,
                                      bool is_backward = false) {
-    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
+    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size(),
+                      platform::errors::InvalidArgument(
+                          "Sizes of source vars and destination vars are not "
+                          "equal in LinkTensor."));
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
       AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
@@ -136,7 +139,10 @@ class RecurrentBase : public framework::OperatorBase {
                                      const std::vector<std::string> &dst_vars,
                                      Callback callback,
                                      bool is_backward = false) {
-    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
+    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size(),
+                      platform::errors::InvalidArgument(
+                          "Sizes of source vars and destination vars are not "
+                          "equal in LinkTensor."));
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
       AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
@@ -159,7 +165,9 @@ class RecurrentBase : public framework::OperatorBase {
     if (is_backward && src_var == nullptr) {
       return;
     }
-    PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        src_var, platform::errors::NotFound("Source variable %s is not found.",
+                                            src_var_name));
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
     auto *dst_var = dst_scope->Var(dst_var_name);
@@ -178,9 +186,13 @@ class RecurrentBase : public framework::OperatorBase {
       return;
     }
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        src_var, platform::errors::NotFound("Source variable %s is not found.",
+                                            src_var_name));
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    PADDLE_ENFORCE_NOT_NULL(dst_var, "%s is not found.", dst_var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        dst_var, platform::errors::NotFound(
+                     "Destination variable %s is not found.", dst_var_name));
     auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
     callback(src_tensor, dst_tensor);
   }
......