未验证 提交 adfa5b83 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add PADDLE_ENFORCE to Check Sequence Length of RecurrentOp (#22673)

1. Add PADDLE_ENFORCE to Check Sequence Length of RecurrentOp.
2. Also enrich PADDLE_ENFORCE error messages.
上级 769c032f
......@@ -65,7 +65,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
is_backward_(is_backward) {
size_t num_step_scopes = is_train ? seq_len : 2;
PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
"Cannot backward when is not training");
platform::errors::PreconditionNotMet(
"Cannot backward when is not training"));
if (!is_backward_) {
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
scopes->reserve(static_cast<size_t>(num_step_scopes));
......@@ -85,7 +86,8 @@ framework::Scope &StepScopes::ExScope() {
void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
framework::Scope *parent_scope) {
PADDLE_ENFORCE_EQ(is_backward_, true,
"Cannot get backward next scope when is forward");
platform::errors::PreconditionNotMet(
"Cannot get backward next scope when is forward"));
if (counter_ + 2 == scopes_->size()) {
parent_scope->DeleteScope((*scopes_)[counter_ + 1]);
scopes_->pop_back();
......@@ -96,7 +98,8 @@ void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
void StepScopes::ForwardNext() {
PADDLE_ENFORCE_EQ(is_backward_, false,
"Cannot get forward next scope when is backward");
platform::errors::PreconditionNotMet(
"Cannot get forward next scope when is backward"));
++counter_;
}
......@@ -104,7 +107,10 @@ framework::Scope &StepScopes::GetScope(size_t scope_id) const {
if (!is_train_) {
scope_id %= 2;
}
PADDLE_ENFORCE_LT(scope_id, scopes_->size());
PADDLE_ENFORCE_LT(
scope_id, scopes_->size(),
platform::errors::InvalidArgument(
"Input scope_id is greater than scopes size in RecurrentOp"));
return *(*scopes_)[scope_id];
}
......@@ -123,18 +129,33 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
// Dim format SEQ_LEN, BATCH_SIZE, ...
int64_t seq_len = -1;
auto &all_inputs = Inputs(kInputs);
PADDLE_ENFORCE_EQ(all_inputs.empty(), false);
PADDLE_ENFORCE_EQ(
all_inputs.empty(), false,
platform::errors::InvalidArgument("RecurrentOp gets empty input"));
for (auto &iname : all_inputs) {
auto *var = scope.FindVar(iname);
PADDLE_ENFORCE_NOT_NULL(var);
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true);
PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::InvalidArgument(
"RecurrentOp finds var %s is NULL", iname));
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"RecurrentOp only accepts LoDTensor as input but "
"input var %s is not LoDTensor",
iname));
auto &dim = var->Get<framework::LoDTensor>().dims();
if (seq_len == -1) {
seq_len = dim[0];
} else {
PADDLE_ENFORCE_EQ(seq_len, dim[0]);
PADDLE_ENFORCE_EQ(seq_len, dim[0],
platform::errors::InvalidArgument(
"Sequence length of input %s in RecurrentOp is NOT "
"equal to sequence length of previous input",
iname));
}
}
PADDLE_ENFORCE_GE(seq_len, 0,
platform::errors::InvalidArgument(
"RecurrentOp gets invalid sequence length."));
return seq_len;
}
......@@ -260,7 +281,8 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Output(kStepScopes));
PADDLE_ENFORCE_NOT_NULL(var);
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::InvalidArgument(
"RecurrentOp gets empty StepScopes var"));
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len);
}
......@@ -328,7 +350,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
auto cur_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kStates));
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size(),
platform::errors::InvalidArgument(
"lengths of ex_states and cur_states are not "
"equal in RecurrentGradOp"));
for (size_t i = 0; i < ex_state_grads.size(); ++i) {
auto &cur_grad = cur_state_grads[i];
auto &ex_grad = ex_state_grads[i];
......@@ -380,7 +405,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
{
auto &pg_names = Outputs(kParamGrads);
auto &p_names = Inputs(kParameters);
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size(),
platform::errors::InvalidArgument(
"Sizes of Parameters and ParamGrads are not equal "
"in RecurrentGradOp"));
for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
auto inside_grad_name = framework::GradVarName(p_names[param_id]);
......@@ -461,7 +489,9 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
}
// Delete the scope of StepScopes
auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE_NOT_NULL(var);
PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::InvalidArgument(
"StepScopes var is empty in RecurrentGradOp"));
auto *step_scopes = var->GetMutable<StepScopeVar>();
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
}
......@@ -470,7 +500,9 @@ StepScopes RecurrentGradOp::CreateStepScopes(
const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
size_t seq_len) const {
auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE_NOT_NULL(var);
PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::InvalidArgument(
"StepScopes var is empty in RecurrentGradOp"));
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
}
......@@ -619,20 +651,24 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
ctx->Attrs()
.Get<std::vector<std::string>>(RecurrentBase::kExStates)
.size(),
0, "The Attr(%s) should be empty.", RecurrentBase::kExStates);
0, platform::errors::InvalidArgument("The Attr(%s) should be empty.",
RecurrentBase::kExStates));
PADDLE_ENFORCE_EQ(
ctx->Attrs()
.Get<std::vector<std::string>>(RecurrentBase::kStates)
.size(),
0, "The Attr(%s) should be empty.", RecurrentBase::kStates);
0, platform::errors::InvalidArgument("The Attr(%s) should be empty.",
RecurrentBase::kStates));
}
PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true,
"The input(%s) should not be empty.",
RecurrentBase::kInputs);
PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true,
"The input(%s) should not be empty.",
RecurrentBase::kOutputs);
PADDLE_ENFORCE_EQ(
ctx->HasInputs(RecurrentBase::kInputs), true,
platform::errors::InvalidArgument("The input(%s) should not be empty.",
RecurrentBase::kInputs));
PADDLE_ENFORCE_EQ(
ctx->HasInputs(RecurrentBase::kOutputs), true,
platform::errors::InvalidArgument("The input(%s) should not be empty.",
RecurrentBase::kOutputs));
// In some case the kInitialStates is empty.
if (ctx->HasInputs(RecurrentBase::kInitialStates) &&
......@@ -644,8 +680,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
PADDLE_ENFORCE_EQ(
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
platform::errors::InvalidArgument(
"The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kInputs));
framework::GradVarName(RecurrentBase::kInputs)));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
ctx->GetInputsDim(RecurrentBase::kInputs));
......@@ -653,8 +690,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
if (ctx->HasInputs(RecurrentBase::kParameters)) {
PADDLE_ENFORCE_EQ(
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
true, "The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kParameters));
true, platform::errors::InvalidArgument(
"The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kParameters)));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
ctx->GetInputsDim(RecurrentBase::kParameters));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册