From adfa5b835406dcf8fab8c385a31fa757a2d67ce0 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Fri, 21 Feb 2020 14:40:32 +0800 Subject: [PATCH] Add PADDLE_ENFORCE to Check Sequence Length of RecurrentOp (#22673) 1. Add PADDLE_ENFORCE to Check Sequence Length of RecurrentOp. 2. Also enrich PADDLE_ENFORCE error messages. --- paddle/fluid/operators/recurrent_op.cc | 88 ++++++++++++++++++-------- 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc index af6d62ea92..d3b4feb5b4 100644 --- a/paddle/fluid/operators/recurrent_op.cc +++ b/paddle/fluid/operators/recurrent_op.cc @@ -65,7 +65,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx, is_backward_(is_backward) { size_t num_step_scopes = is_train ? seq_len : 2; PADDLE_ENFORCE_EQ(is_train || !is_backward, true, - "Cannot backward when is not training"); + platform::errors::PreconditionNotMet( + "Cannot backward when is not training")); if (!is_backward_) { ClearStepScopes(dev_ctx, const_cast(&parent), scopes); scopes->reserve(static_cast(num_step_scopes)); @@ -85,7 +86,8 @@ framework::Scope &StepScopes::ExScope() { void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx, framework::Scope *parent_scope) { PADDLE_ENFORCE_EQ(is_backward_, true, - "Cannot get backward next scope when is forward"); + platform::errors::PreconditionNotMet( + "Cannot get backward next scope when is forward")); if (counter_ + 2 == scopes_->size()) { parent_scope->DeleteScope((*scopes_)[counter_ + 1]); scopes_->pop_back(); @@ -96,7 +98,8 @@ void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx, void StepScopes::ForwardNext() { PADDLE_ENFORCE_EQ(is_backward_, false, - "Cannot get forward next scope when is backward"); + platform::errors::PreconditionNotMet( + "Cannot get forward next scope when is backward")); ++counter_; } @@ -104,7 +107,10 @@ framework::Scope &StepScopes::GetScope(size_t scope_id) const { if (!is_train_) { scope_id %= 2; } - PADDLE_ENFORCE_LT(scope_id, scopes_->size()); + PADDLE_ENFORCE_LT( + scope_id, scopes_->size(), + platform::errors::InvalidArgument( + "Input scope_id is greater than scopes size in RecurrentOp")); return *(*scopes_)[scope_id]; } @@ -123,18 +129,33 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const { // Dim format SEQ_LEN, BATCH_SIZE, ... int64_t seq_len = -1; auto &all_inputs = Inputs(kInputs); - PADDLE_ENFORCE_EQ(all_inputs.empty(), false); + PADDLE_ENFORCE_EQ( + all_inputs.empty(), false, + platform::errors::InvalidArgument("RecurrentOp gets empty input")); for (auto &iname : all_inputs) { auto *var = scope.FindVar(iname); - PADDLE_ENFORCE_NOT_NULL(var); - PADDLE_ENFORCE_EQ(var->IsType(), true); + PADDLE_ENFORCE_NOT_NULL(var, + platform::errors::InvalidArgument( + "RecurrentOp finds var %s is NULL", iname)); + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::InvalidArgument( + "RecurrentOp only accepts LoDTensor as input but " + "input var %s is not LoDTensor", + iname)); auto &dim = var->Get().dims(); if (seq_len == -1) { seq_len = dim[0]; } else { - PADDLE_ENFORCE_EQ(seq_len, dim[0]); + PADDLE_ENFORCE_EQ(seq_len, dim[0], + platform::errors::InvalidArgument( + "Sequence length of input %s in RecurrentOp is NOT " + "equal to sequence length of previous input", + iname)); } } + PADDLE_ENFORCE_GE(seq_len, 0, + platform::errors::InvalidArgument( + "RecurrentOp gets invalid sequence length.")); return seq_len; } @@ -260,7 +281,8 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx, const framework::Scope &scope, size_t seq_len) const { auto *var = scope.FindVar(Output(kStepScopes)); - PADDLE_ENFORCE_NOT_NULL(var); + PADDLE_ENFORCE_NOT_NULL(var, platform::errors::InvalidArgument( + "RecurrentOp gets empty StepScopes var")); return StepScopes(dev_ctx, scope, var->GetMutable(), Attr(kIsTrain), seq_len); } @@ -328,7 +350,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, auto cur_state_grads = GradVarLists(Attr>(kStates)); - PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size()); + PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size(), + platform::errors::InvalidArgument( + "lengths of ex_states and cur_states are not " + "equal in RecurrentGradOp")); for (size_t i = 0; i < ex_state_grads.size(); ++i) { auto &cur_grad = cur_state_grads[i]; auto &ex_grad = ex_state_grads[i]; @@ -380,7 +405,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, { auto &pg_names = Outputs(kParamGrads); auto &p_names = Inputs(kParameters); - PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size()); + PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size(), + platform::errors::InvalidArgument( + "Sizes of Parameters and ParamGrads are not equal " + "in RecurrentGradOp")); for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) { auto inside_grad_name = framework::GradVarName(p_names[param_id]); @@ -461,7 +489,9 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope, } // Delete the scope of StepScopes auto *var = scope.FindVar(Input(kStepScopes)); - PADDLE_ENFORCE_NOT_NULL(var); + PADDLE_ENFORCE_NOT_NULL(var, + platform::errors::InvalidArgument( + "StepScopes var is empty in RecurrentGradOp")); auto *step_scopes = var->GetMutable(); ClearStepScopes(dev_ctx, const_cast(&scope), step_scopes); } @@ -470,7 +500,9 @@ StepScopes RecurrentGradOp::CreateStepScopes( const platform::DeviceContext &dev_ctx, const framework::Scope &scope, size_t seq_len) const { auto *var = scope.FindVar(Input(kStepScopes)); - PADDLE_ENFORCE_NOT_NULL(var); + PADDLE_ENFORCE_NOT_NULL(var, + platform::errors::InvalidArgument( + "StepScopes var is empty in RecurrentGradOp")); return StepScopes(dev_ctx, scope, var->GetMutable(), Attr(kIsTrain), seq_len, true /*is_backward*/); } @@ -619,20 +651,24 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ctx->Attrs() .Get>(RecurrentBase::kExStates) .size(), - 0, "The Attr(%s) should be empty.", RecurrentBase::kExStates); + 0, platform::errors::InvalidArgument("The Attr(%s) should be empty.", + RecurrentBase::kExStates)); PADDLE_ENFORCE_EQ( ctx->Attrs() .Get>(RecurrentBase::kStates) .size(), - 0, "The Attr(%s) should be empty.", RecurrentBase::kStates); + 0, platform::errors::InvalidArgument("The Attr(%s) should be empty.", + RecurrentBase::kStates)); } - PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true, - "The input(%s) should not be empty.", - RecurrentBase::kInputs); - PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true, - "The input(%s) should not be empty.", - RecurrentBase::kOutputs); + PADDLE_ENFORCE_EQ( + ctx->HasInputs(RecurrentBase::kInputs), true, + platform::errors::InvalidArgument("The input(%s) should not be empty.", + RecurrentBase::kInputs)); + PADDLE_ENFORCE_EQ( + ctx->HasInputs(RecurrentBase::kOutputs), true, + platform::errors::InvalidArgument("The input(%s) should not be empty.", + RecurrentBase::kOutputs)); // In some case the kInitialStates is empty. if (ctx->HasInputs(RecurrentBase::kInitialStates) && @@ -644,8 +680,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { PADDLE_ENFORCE_EQ( ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true, - "The output of(%s) should not be empty.", - framework::GradVarName(RecurrentBase::kInputs)); + platform::errors::InvalidArgument( + "The output of(%s) should not be empty.", + framework::GradVarName(RecurrentBase::kInputs))); ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs), ctx->GetInputsDim(RecurrentBase::kInputs)); @@ -653,8 +690,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { if (ctx->HasInputs(RecurrentBase::kParameters)) { PADDLE_ENFORCE_EQ( ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)), - true, "The output of(%s) should not be empty.", - framework::GradVarName(RecurrentBase::kParameters)); + true, platform::errors::InvalidArgument( + "The output of(%s) should not be empty.", + framework::GradVarName(RecurrentBase::kParameters))); ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters), ctx->GetInputsDim(RecurrentBase::kParameters)); } -- GitLab