diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc
index 57db82505624914bb010ca041e8315c4cf9d5ad5..efc3ba056f1a2151ddfa96305cca7dc2bc73f2f6 100644
--- a/paddle/fluid/operators/recurrent_op.cc
+++ b/paddle/fluid/operators/recurrent_op.cc
@@ -54,20 +54,6 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
   step_scopes->clear();
 }
 
-// StepScopes manages scopes inside RNN.
-// StepScopes::CurScope() get the current scope
-// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
-// StepScopes::Next() move to next time step.
-//
-// if is_train = False, then
-// there are two scopes for the RNN and just support forward.
-// else
-// the len(scopes) == seq_len
-//
-// if is_backward = True, then
-// reversely access scopes
-// else
-// access scopes from begin to end.
 StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
                        const framework::Scope &parent, StepScopeVar *scopes,
                        bool is_train, size_t seq_len, bool is_backward)
@@ -76,8 +62,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
       is_train_(is_train),
       is_backward_(is_backward) {
   size_t num_step_scopes = is_train ? seq_len : 2;
-  PADDLE_ENFORCE(is_train || !is_backward,
-                 "Cannot backward when is not training");
+  PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
+                    "Cannot backward when is not training");
   if (!is_backward_) {
     ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
     scopes->reserve(static_cast<size_t>(num_step_scopes));
@@ -94,12 +80,22 @@ framework::Scope &StepScopes::ExScope() {
   return scope;
 }
 
-void StepScopes::Next() {
-  if (is_backward_) {
-    --counter_;
-  } else {
-    ++counter_;
+void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
+                              framework::Scope *parent_scope) {
+  PADDLE_ENFORCE_EQ(is_backward_, true,
+                    "Cannot get backward next scope when is forward");
+  if (counter_ + 2 == scopes_->size()) {
+    parent_scope->DeleteScope((*scopes_)[counter_ + 1]);
+    scopes_->pop_back();
+    VLOG(3) << "Deleted scope at " << counter_ + 1;
   }
+  --counter_;
+}
+
+void StepScopes::ForwardNext() {
+  PADDLE_ENFORCE_EQ(is_backward_, false,
+                    "Cannot get forward next scope when is backward");
+  ++counter_;
 }
 
 framework::Scope &StepScopes::GetScope(size_t scope_id) const {
@@ -125,11 +121,11 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
   // Dim format SEQ_LEN, BATCH_SIZE, ...
   int64_t seq_len = -1;
   auto &all_inputs = Inputs(kInputs);
-  PADDLE_ENFORCE(!all_inputs.empty());
+  PADDLE_ENFORCE_EQ(!all_inputs.empty(), true);
   for (auto &iname : all_inputs) {
     auto *var = scope.FindVar(iname);
-    PADDLE_ENFORCE(var != nullptr);
-    PADDLE_ENFORCE(var->IsType<framework::LoDTensor>());
+    PADDLE_ENFORCE_NOT_NULL(var);
+    PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true);
     auto &dim = var->Get<framework::LoDTensor>().dims();
     if (seq_len == -1) {
       seq_len = dim[0];
@@ -254,7 +250,7 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
           });
     }
 
-    scopes.Next();
+    scopes.ForwardNext();
   }
 }
 
@@ -262,7 +258,7 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
                                          const framework::Scope &scope,
                                          size_t seq_len) const {
   auto *var = scope.FindVar(Output(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
   return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                     Attr<bool>(kIsTrain), seq_len);
 }
@@ -459,11 +455,11 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
         VLOG(5) << "Link initialize state gradient finished ";
       }
     }
-    scopes.Next();
+    scopes.BackwardNext(dev_ctx, const_cast<framework::Scope *>(&scope));
   }
   // Delete the scope of StepScopes
   auto *var = scope.FindVar(Input(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
   auto *step_scopes = var->GetMutable<StepScopeVar>();
   ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
 }
@@ -472,7 +468,7 @@ StepScopes RecurrentGradOp::CreateStepScopes(
     const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
     size_t seq_len) const {
   auto *var = scope.FindVar(Input(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
   return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                     Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
 }
@@ -491,6 +487,7 @@ std::unordered_set<std::string> RecurrentGradOp::LocalVarNames(
     const framework::Scope &scope) const {
   return this->List2Set(scope.LocalVarNames());
 }
+
 std::vector<std::string> RecurrentGradOp::GradVarLists(
     const std::vector<std::string> &var_names) {
   std::vector<std::string> retv;
@@ -627,25 +624,25 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
                         0, "The Attr(%s) should be empty.",
                         RecurrentBase::kStates);
     }
-    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kInputs),
-                   "The input(%s) should not be empty.",
-                   RecurrentBase::kInputs);
-    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kOutputs),
-                   "The input(%s) should not be empty.",
-                   RecurrentBase::kOutputs);
+    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true,
+                      "The input(%s) should not be empty.",
+                      RecurrentBase::kInputs);
+    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true,
+                      "The input(%s) should not be empty.",
+                      RecurrentBase::kOutputs);
 
     // In some case the kInitialStates is empty.
     if (ctx->HasInputs(RecurrentBase::kInitialStates)) {
-      PADDLE_ENFORCE(ctx->HasOutputs(
-                         framework::GradVarName(RecurrentBase::kInitialStates)),
-                     "The output of(%s) should not be empty.",
-                     framework::GradVarName(RecurrentBase::kInitialStates));
+      PADDLE_ENFORCE_EQ(ctx->HasOutputs(framework::GradVarName(
+                            RecurrentBase::kInitialStates)),
+                        true, "The output of(%s) should not be empty.",
+                        framework::GradVarName(RecurrentBase::kInitialStates));
       ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates),
                          ctx->GetInputsDim(RecurrentBase::kInitialStates));
     }
 
-    PADDLE_ENFORCE(
-        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)),
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
         "The output of(%s) should not be empty.",
         framework::GradVarName(RecurrentBase::kInputs));
     ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
@@ -653,9 +650,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
 
     // In some case the kParameters is empty.
     if (ctx->HasInputs(RecurrentBase::kParameters)) {
-      PADDLE_ENFORCE(
+      PADDLE_ENFORCE_EQ(
           ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
-          "The output of(%s) should not be empty.",
+          true, "The output of(%s) should not be empty.",
           framework::GradVarName(RecurrentBase::kParameters));
       ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
                          ctx->GetInputsDim(RecurrentBase::kParameters));
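The behavioral core of the recurrent_op.cc changes above is StepScopes::BackwardNext: rather than only decrementing the counter as the old Next() did, it frees each step scope as soon as the backward pass has finished reading it, so training no longer holds all seq_len scopes alive until the end of the pass. Below is a minimal, self-contained sketch of that bookkeeping; it is plain C++ with a placeholder Scope type standing in for framework::Scope, not Paddle code:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct Scope {};  // placeholder for framework::Scope

    int main() {
      const std::size_t seq_len = 4;
      std::vector<Scope *> scopes;
      for (std::size_t i = 0; i < seq_len; ++i) scopes.push_back(new Scope);

      // Backward traversal: the counter starts at seq_len - 1, as in StepScopes.
      for (std::size_t counter = seq_len; counter-- > 0;) {
        // ... run backward step `counter`; when it exists, its ex-scope is
        // scopes[counter + 1], the step processed just before this one ...

        // BackwardNext(): once step `counter` is finished, scopes[counter + 1]
        // can never be read again, so it is released immediately instead of
        // staying alive until the whole pass ends.
        if (counter + 2 == scopes.size()) {
          delete scopes[counter + 1];  // parent_scope->DeleteScope(...) in Paddle
          scopes.pop_back();
          std::cout << "Deleted scope at " << counter + 1 << "\n";
        }
      }

      // scopes[0] remains; RecurrentGradOp::RunImpl clears it afterwards via
      // ClearStepScopes(), as the hunk above shows.
      for (Scope *s : scopes) delete s;
    }

With seq_len = 4 this prints "Deleted scope at 3", then 2, then 1: each scope is reclaimed one step after it stops being an ex-scope.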
diff --git a/paddle/fluid/operators/recurrent_op.h b/paddle/fluid/operators/recurrent_op.h
index 8da0fcacee2c43a5e24a6638b7c100418f7df904..a4b21448a6057054d1520ce660758cb037667315 100644
--- a/paddle/fluid/operators/recurrent_op.h
+++ b/paddle/fluid/operators/recurrent_op.h
@@ -25,20 +25,17 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-// StepScopes manages scopes inside RNN.
-// StepScopes::CurScope() get the current scope
-// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
-// StepScopes::Next() move to next time step.
+// StepScopes manages the scopes inside Recurrent Op.
 //
 // if is_train = False, then
-// there are two scopes for the RNN and just support forward.
+// there are only two scopes for the RNN and only forward is supported
 // else
 // the len(scopes) == seq_len
 //
 // if is_backward = True, then
-// reversely access scopes
+// access scopes in reverse, deleting each ex-scope once it is no longer needed
 // else
-// access scopes from begin to end.
+// access scopes from beginning to end
 class StepScopes {
  public:
   StepScopes(const platform::DeviceContext &dev_ctx,
@@ -46,11 +43,19 @@ class StepScopes {
              std::vector<framework::Scope *> *scopes, bool is_train,
              size_t seq_len, bool is_backward = false);
 
+  // Get the current scope
   framework::Scope &CurScope();
 
+  // Get the ex-scope, i.e., the scope of the previous time step
   framework::Scope &ExScope();
 
-  void Next();
+  // Move to the next time step in the forward pass
+  void ForwardNext();
+
+  // Delete the ex-scope once it has been used, then move to the next
+  // time step in the backward pass
+  void BackwardNext(const platform::DeviceContext &dev_ctx,
+                    framework::Scope *parent_scope);
 
  private:
   framework::Scope &GetScope(size_t scope_id) const;
@@ -154,7 +159,7 @@ class RecurrentBase : public framework::OperatorBase {
       if (is_backward && src_var == nullptr) {
         return;
       }
-      PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
+      PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
       auto &src_tensor = src_var->Get<framework::LoDTensor>();
 
       auto *dst_var = dst_scope->Var(dst_var_name);
@@ -173,9 +178,9 @@ class RecurrentBase : public framework::OperatorBase {
       return;
     }
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
+    PADDLE_ENFORCE_NOT_NULL(src_var, "%s is not found.", src_var_name);
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
+    PADDLE_ENFORCE_NOT_NULL(dst_var, "%s is not found.", dst_var_name);
     auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
     callback(src_tensor, dst_tensor);
   }
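For the forward direction, the two cases in the header comment can be modeled with a small counter. The following is an illustrative sketch only: MiniStepScopes is invented for this example, and the modulo-2 reuse in inference is an assumption about how GetScope maps the counter onto the two scopes, since GetScope's body is not part of this diff.

    #include <cstddef>
    #include <iostream>

    // Illustrative model of StepScopes' forward indexing: training keeps one
    // scope per time step (len(scopes) == seq_len), while inference keeps only
    // two and reuses them, because a step needs only itself and its ex-scope.
    struct MiniStepScopes {
      std::size_t counter = 0;
      bool is_train;
      std::size_t num_scopes;  // seq_len when training, 2 otherwise

      MiniStepScopes(bool train, std::size_t seq_len)
          : is_train(train), num_scopes(train ? seq_len : 2) {}

      // Assumed reduction: in inference the counter wraps onto the two scopes.
      std::size_t CurScope() const { return is_train ? counter : counter % 2; }
      std::size_t ExScope() const {  // valid for counter > 0
        return is_train ? counter - 1 : (counter - 1) % 2;
      }
      void ForwardNext() { ++counter; }  // forward never deletes scopes
    };

    int main() {
      const std::size_t seq_len = 5;
      MiniStepScopes train(true, seq_len), infer(false, seq_len);
      for (std::size_t t = 0; t < seq_len; ++t) {
        std::cout << "t=" << t << "  train cur=" << train.CurScope()
                  << "  infer cur=" << infer.CurScope();
        if (t > 0) std::cout << "  infer ex=" << infer.ExScope();
        std::cout << "\n";
        train.ForwardNext();
        infer.ForwardNext();
      }
    }

Running it shows the training counter visiting scopes 0 through 4 while the inference counter ping-pongs between scopes 0 and 1, which is why two scopes suffice when is_train is false.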