Unverified · Commit 71b5f1d2 authored by Huihuang Zheng, committed by GitHub

OP (recurrent) error message enhancement (#23481)

* OP (recurrent) error message enhancement
Parent 8674a82c
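Every change in the diff below applies the same pattern: a PADDLE_ENFORCE_* check that previously carried only a bare message string now constructs a typed error (platform::errors::InvalidArgument, PreconditionNotMet, or NotFound) and interpolates the offending values into the message. The sketch below is a minimal, hypothetical illustration of that pattern, not code from this commit: the CheckSameSize helper and its arguments are invented for illustration, and the include paths are assumed from the usual Paddle source layout.

#include "paddle/fluid/platform/enforce.h"  // PADDLE_ENFORCE_* macros (assumed path)
#include "paddle/fluid/platform/errors.h"   // platform::errors::* constructors (assumed path)

namespace paddle {
namespace operators {

// Hypothetical helper showing the enhanced error-message style: a typed
// error category plus the actual values, instead of a bare string literal.
inline void CheckSameSize(size_t forward_size, size_t backward_size) {
  PADDLE_ENFORCE_EQ(
      forward_size, backward_size,
      platform::errors::InvalidArgument(
          "Forward and backward sizes must be equal, but received forward "
          "size = %d and backward size = %d.",
          forward_size, backward_size));
}

}  // namespace operators
}  // namespace paddle

With a typed category the failure can be classified consistently (for example, mapped to a corresponding Python exception type), and the interpolated sizes tell the user exactly what mismatched without rerunning under a debugger.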
@@ -72,8 +72,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
   OpVariantSet &ops = op_and_grad_op->first;
   OpVariantSet &grad_ops = op_and_grad_op->second;
 
-  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
-                    "There are extra grad ops in the graph or program");
+  PADDLE_ENFORCE_GE(
+      ops.size(), grad_ops.size(),
+      platform::errors::InvalidArgument(
+          "There are more grad ops than forward ops in the graph or program, "
+          "the number of ops is %d and the number of grad_ops is %d.",
+          ops.size(), grad_ops.size()));
 
   for (size_t i = 1; i < program.Size(); ++i) {
     auto &block = program.Block(i);
@@ -87,8 +91,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
     }
   }
 
-  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
-                    "There are extra grad ops in the graph or program");
+  PADDLE_ENFORCE_GE(
+      ops.size(), grad_ops.size(),
+      platform::errors::InvalidArgument(
+          "There are more grad ops than forward ops in the graph or program, "
+          "the number of ops is %d and the number of grad_ops is %d.",
+          ops.size(), grad_ops.size()));
 }
 
 // Returns GradVarName of input var names
@@ -169,7 +177,11 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
   PADDLE_ENFORCE_EQ(
       fwd_input.size(), in_grads.size(),
-      "Backward input gradient number does not match forward input number.");
+      platform::errors::PreconditionNotMet(
+          "Backward input gradient number does not match forward "
+          "input number. The number of forward input number is %d and the "
+          "number of backward input gradient number is %d.",
+          fwd_input.size(), in_grads.size()));
 
   for (size_t i = 0; i < in_grads.size(); ++i) {
     if (in_grads[i] == framework::kEmptyVarName) {
       continue;
@@ -181,9 +193,13 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
   auto &fwd_param = fwd_op.Inputs().at(RecurrentBase::kParameters);
   auto &param_grads =
       bwd_op.Outputs().at(framework::GradVarName(RecurrentBase::kParameters));
-  PADDLE_ENFORCE_EQ(fwd_param.size(), param_grads.size(),
-                    "Backward parameter gradient number does not match forward "
-                    "parameter number.");
+  PADDLE_ENFORCE_EQ(
+      fwd_param.size(), param_grads.size(),
+      platform::errors::PreconditionNotMet(
+          "Backward parameter gradient number does not match "
+          "forward parameter number. The number of forward parameter number is "
+          "%d and the number of backward parameter gradient is %d.",
+          fwd_param.size(), param_grads.size()));
   for (size_t i = 0; i < fwd_param.size(); ++i) {
     if (param_grads[i] == framework::kEmptyVarName) {
       continue;
@@ -241,12 +257,16 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
     const OpVariant *matched_fwd_op = nullptr;
     for (auto &fwd_op : recurrent_ops) {
       if (IsMatchedRecurrentOpAndRecurrentGradOp(fwd_op, bwd_op)) {
-        PADDLE_ENFORCE(matched_fwd_op == nullptr,
-                       "Found multiple matched recurrent op");
+        PADDLE_ENFORCE_EQ(matched_fwd_op, nullptr,
+                          platform::errors::PreconditionNotMet(
+                              "Found multiple recurrent forward op matches "
+                              "recurrent grad op."));
         matched_fwd_op = &fwd_op;
       }
     }
-    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, "Cannot find matched forward op");
+    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
+                            platform::errors::PreconditionNotMet(
+                                "Cannot find matched forward op."));
     SetRecurrentOpAndRecurrentGradOpSkipVarAttr(*matched_fwd_op, bwd_op);
     recurrent_ops.erase(*matched_fwd_op);
   }
@@ -118,7 +118,10 @@ class RecurrentBase : public framework::OperatorBase {
                          const std::vector<std::string> &dst_vars,
                          Callback callback,
                          bool is_backward = false) {
-    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
+    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size(),
+                      platform::errors::InvalidArgument(
+                          "Sizes of source vars and destination vars are not "
+                          "equal in LinkTensor."));
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
       AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
@@ -136,7 +139,10 @@ class RecurrentBase : public framework::OperatorBase {
                          const std::vector<std::string> &dst_vars,
                          Callback callback,
                          bool is_backward = false) {
-    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
+    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size(),
+                      platform::errors::InvalidArgument(
+                          "Sizes of source vars and destination vars are not "
+                          "equal in LinkTensor."));
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
       AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
@@ -159,7 +165,9 @@ class RecurrentBase : public framework::OperatorBase {
     if (is_backward && src_var == nullptr) {
       return;
     }
-    PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        src_var, platform::errors::NotFound("Source variable %s is not found.",
+                                            src_var_name));
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
 
     auto *dst_var = dst_scope->Var(dst_var_name);
@@ -178,9 +186,13 @@ class RecurrentBase : public framework::OperatorBase {
       return;
     }
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        src_var, platform::errors::NotFound("Source variable %s is not found.",
+                                            src_var_name));
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    PADDLE_ENFORCE_NOT_NULL(dst_var, "%s is not found.", dst_var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        dst_var, platform::errors::NotFound(
+                     "Destination variable %s is not found.", dst_var_name));
     auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
     callback(src_tensor, dst_tensor);
   }