未验证 提交 adfa5b83 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add PADDLE_ENFORCE to Check Sequence Length of RecurrentOp (#22673)

1. Add PADDLE_ENFORCE to Check Sequence Length of RecurrentOp.
2. Also enrich PADDLE_ENFORCE error messages.
上级 769c032f
...@@ -65,7 +65,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx, ...@@ -65,7 +65,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
is_backward_(is_backward) { is_backward_(is_backward) {
size_t num_step_scopes = is_train ? seq_len : 2; size_t num_step_scopes = is_train ? seq_len : 2;
PADDLE_ENFORCE_EQ(is_train || !is_backward, true, PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
"Cannot backward when is not training"); platform::errors::PreconditionNotMet(
"Cannot backward when is not training"));
if (!is_backward_) { if (!is_backward_) {
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes); ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
scopes->reserve(static_cast<size_t>(num_step_scopes)); scopes->reserve(static_cast<size_t>(num_step_scopes));
...@@ -85,7 +86,8 @@ framework::Scope &StepScopes::ExScope() { ...@@ -85,7 +86,8 @@ framework::Scope &StepScopes::ExScope() {
void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx, void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
framework::Scope *parent_scope) { framework::Scope *parent_scope) {
PADDLE_ENFORCE_EQ(is_backward_, true, PADDLE_ENFORCE_EQ(is_backward_, true,
"Cannot get backward next scope when is forward"); platform::errors::PreconditionNotMet(
"Cannot get backward next scope when is forward"));
if (counter_ + 2 == scopes_->size()) { if (counter_ + 2 == scopes_->size()) {
parent_scope->DeleteScope((*scopes_)[counter_ + 1]); parent_scope->DeleteScope((*scopes_)[counter_ + 1]);
scopes_->pop_back(); scopes_->pop_back();
...@@ -96,7 +98,8 @@ void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx, ...@@ -96,7 +98,8 @@ void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
void StepScopes::ForwardNext() { void StepScopes::ForwardNext() {
PADDLE_ENFORCE_EQ(is_backward_, false, PADDLE_ENFORCE_EQ(is_backward_, false,
"Cannot get forward next scope when is backward"); platform::errors::PreconditionNotMet(
"Cannot get forward next scope when is backward"));
++counter_; ++counter_;
} }
...@@ -104,7 +107,10 @@ framework::Scope &StepScopes::GetScope(size_t scope_id) const { ...@@ -104,7 +107,10 @@ framework::Scope &StepScopes::GetScope(size_t scope_id) const {
if (!is_train_) { if (!is_train_) {
scope_id %= 2; scope_id %= 2;
} }
PADDLE_ENFORCE_LT(scope_id, scopes_->size()); PADDLE_ENFORCE_LT(
scope_id, scopes_->size(),
platform::errors::InvalidArgument(
"Input scope_id is greater than scopes size in RecurrentOp"));
return *(*scopes_)[scope_id]; return *(*scopes_)[scope_id];
} }
...@@ -123,18 +129,33 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const { ...@@ -123,18 +129,33 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
// Dim format SEQ_LEN, BATCH_SIZE, ... // Dim format SEQ_LEN, BATCH_SIZE, ...
int64_t seq_len = -1; int64_t seq_len = -1;
auto &all_inputs = Inputs(kInputs); auto &all_inputs = Inputs(kInputs);
PADDLE_ENFORCE_EQ(all_inputs.empty(), false); PADDLE_ENFORCE_EQ(
all_inputs.empty(), false,
platform::errors::InvalidArgument("RecurrentOp gets empty input"));
for (auto &iname : all_inputs) { for (auto &iname : all_inputs) {
auto *var = scope.FindVar(iname); auto *var = scope.FindVar(iname);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var,
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true); platform::errors::InvalidArgument(
"RecurrentOp finds var %s is NULL", iname));
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"RecurrentOp only accepts LoDTensor as input but "
"input var %s is not LoDTensor",
iname));
auto &dim = var->Get<framework::LoDTensor>().dims(); auto &dim = var->Get<framework::LoDTensor>().dims();
if (seq_len == -1) { if (seq_len == -1) {
seq_len = dim[0]; seq_len = dim[0];
} else { } else {
PADDLE_ENFORCE_EQ(seq_len, dim[0]); PADDLE_ENFORCE_EQ(seq_len, dim[0],
platform::errors::InvalidArgument(
"Sequence length of input %s in RecurrentOp is NOT "
"equal to sequence length of previous input",
iname));
} }
} }
PADDLE_ENFORCE_GE(seq_len, 0,
platform::errors::InvalidArgument(
"RecurrentOp gets invalid sequence length."));
return seq_len; return seq_len;
} }
...@@ -260,7 +281,8 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx, ...@@ -260,7 +281,8 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
const framework::Scope &scope, const framework::Scope &scope,
size_t seq_len) const { size_t seq_len) const {
auto *var = scope.FindVar(Output(kStepScopes)); auto *var = scope.FindVar(Output(kStepScopes));
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var, platform::errors::InvalidArgument(
"RecurrentOp gets empty StepScopes var"));
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(), return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len); Attr<bool>(kIsTrain), seq_len);
} }
...@@ -328,7 +350,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, ...@@ -328,7 +350,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
auto cur_state_grads = auto cur_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kStates)); GradVarLists(Attr<std::vector<std::string>>(kStates));
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size()); PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size(),
platform::errors::InvalidArgument(
"lengths of ex_states and cur_states are not "
"equal in RecurrentGradOp"));
for (size_t i = 0; i < ex_state_grads.size(); ++i) { for (size_t i = 0; i < ex_state_grads.size(); ++i) {
auto &cur_grad = cur_state_grads[i]; auto &cur_grad = cur_state_grads[i];
auto &ex_grad = ex_state_grads[i]; auto &ex_grad = ex_state_grads[i];
...@@ -380,7 +405,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, ...@@ -380,7 +405,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
{ {
auto &pg_names = Outputs(kParamGrads); auto &pg_names = Outputs(kParamGrads);
auto &p_names = Inputs(kParameters); auto &p_names = Inputs(kParameters);
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size()); PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size(),
platform::errors::InvalidArgument(
"Sizes of Parameters and ParamGrads are not equal "
"in RecurrentGradOp"));
for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) { for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
auto inside_grad_name = framework::GradVarName(p_names[param_id]); auto inside_grad_name = framework::GradVarName(p_names[param_id]);
...@@ -461,7 +489,9 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, ...@@ -461,7 +489,9 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
} }
// Delete the scope of StepScopes // Delete the scope of StepScopes
auto *var = scope.FindVar(Input(kStepScopes)); auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::InvalidArgument(
"StepScopes var is empty in RecurrentGradOp"));
auto *step_scopes = var->GetMutable<StepScopeVar>(); auto *step_scopes = var->GetMutable<StepScopeVar>();
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes); ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
} }
...@@ -470,7 +500,9 @@ StepScopes RecurrentGradOp::CreateStepScopes( ...@@ -470,7 +500,9 @@ StepScopes RecurrentGradOp::CreateStepScopes(
const platform::DeviceContext &dev_ctx, const framework::Scope &scope, const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
size_t seq_len) const { size_t seq_len) const {
auto *var = scope.FindVar(Input(kStepScopes)); auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::InvalidArgument(
"StepScopes var is empty in RecurrentGradOp"));
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(), return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/); Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
} }
...@@ -619,20 +651,24 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ...@@ -619,20 +651,24 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
ctx->Attrs() ctx->Attrs()
.Get<std::vector<std::string>>(RecurrentBase::kExStates) .Get<std::vector<std::string>>(RecurrentBase::kExStates)
.size(), .size(),
0, "The Attr(%s) should be empty.", RecurrentBase::kExStates); 0, platform::errors::InvalidArgument("The Attr(%s) should be empty.",
RecurrentBase::kExStates));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->Attrs() ctx->Attrs()
.Get<std::vector<std::string>>(RecurrentBase::kStates) .Get<std::vector<std::string>>(RecurrentBase::kStates)
.size(), .size(),
0, "The Attr(%s) should be empty.", RecurrentBase::kStates); 0, platform::errors::InvalidArgument("The Attr(%s) should be empty.",
RecurrentBase::kStates));
} }
PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true, PADDLE_ENFORCE_EQ(
"The input(%s) should not be empty.", ctx->HasInputs(RecurrentBase::kInputs), true,
RecurrentBase::kInputs); platform::errors::InvalidArgument("The input(%s) should not be empty.",
PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true, RecurrentBase::kInputs));
"The input(%s) should not be empty.", PADDLE_ENFORCE_EQ(
RecurrentBase::kOutputs); ctx->HasInputs(RecurrentBase::kOutputs), true,
platform::errors::InvalidArgument("The input(%s) should not be empty.",
RecurrentBase::kOutputs));
// In some case the kInitialStates is empty. // In some case the kInitialStates is empty.
if (ctx->HasInputs(RecurrentBase::kInitialStates) && if (ctx->HasInputs(RecurrentBase::kInitialStates) &&
...@@ -644,8 +680,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ...@@ -644,8 +680,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true, ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
"The output of(%s) should not be empty.", platform::errors::InvalidArgument(
framework::GradVarName(RecurrentBase::kInputs)); "The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kInputs)));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs), ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
ctx->GetInputsDim(RecurrentBase::kInputs)); ctx->GetInputsDim(RecurrentBase::kInputs));
...@@ -653,8 +690,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ...@@ -653,8 +690,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
if (ctx->HasInputs(RecurrentBase::kParameters)) { if (ctx->HasInputs(RecurrentBase::kParameters)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)), ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
true, "The output of(%s) should not be empty.", true, platform::errors::InvalidArgument(
framework::GradVarName(RecurrentBase::kParameters)); "The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kParameters)));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters), ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
ctx->GetInputsDim(RecurrentBase::kParameters)); ctx->GetInputsDim(RecurrentBase::kParameters));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册