From 3d4e8268c6aba121f75782279744a5a5e111ab13 Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Sun, 19 May 2019 08:01:48 -0500
Subject: [PATCH] fix recurrent fwd bug when no backward and scope clear
 (#17460)

---
 paddle/fluid/operators/recurrent_op.cc | 56 +++++++++++++++++---------
 1 file changed, 36 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc
index 6ead10c9987..ac432f4dd03 100644
--- a/paddle/fluid/operators/recurrent_op.cc
+++ b/paddle/fluid/operators/recurrent_op.cc
@@ -37,6 +37,20 @@ constexpr char kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
 
 using StepScopeVar = std::vector<framework::Scope *>;
 
+static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
+                            framework::Scope *parent_scope,
+                            StepScopeVar *step_scopes) {
+  if (step_scopes->empty()) return;
+
+  dev_ctx.Wait();
+
+  for (auto *sub_scope : *step_scopes) {
+    parent_scope->DeleteScope(sub_scope);
+  }
+
+  step_scopes->clear();
+}
+
 // StepScopes manages scopes inside RNN.
 // StepScopes::CurScope() get the current scope
 // StepScopes::ExScope() get the ex-scope, or scope in previous time step.
@@ -53,7 +67,8 @@ using StepScopeVar = std::vector<framework::Scope *>;
 // access scopes from begin to end.
 class StepScopes {
  public:
-  StepScopes(const framework::Scope &parent, StepScopeVar *scopes,
+  StepScopes(const platform::DeviceContext &dev_ctx,
+             const framework::Scope &parent, StepScopeVar *scopes,
              bool is_train, size_t seq_len, bool is_backward = false)
       : counter_(is_backward ? seq_len - 1 : 0UL),
         scopes_(scopes),
@@ -63,7 +78,7 @@ class StepScopes {
     PADDLE_ENFORCE(is_train || !is_backward,
                    "Cannot backward when is not training");
     if (!is_backward_) {
-      PADDLE_ENFORCE(scopes->empty());
+      ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
       scopes->reserve(static_cast<size_t>(num_step_scopes));
       for (size_t i = 0; i < num_step_scopes; ++i) {
         scopes->emplace_back(&parent.NewScope());
@@ -244,14 +259,15 @@ class RecurrentOp : public RecurrentBase {
                const platform::Place &place) const override {
     bool has_state = Attr<bool>(kHasStates);
     auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
-    VLOG(3) << "Static RNN input sequence length = " << seq_len;
-    StepScopes scopes = CreateStepScopes(scope, seq_len);
-    auto reverse = Attr<bool>(kReverse);
 
     // get device context from pool
     platform::DeviceContextPool &pool =
         platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);
 
+    VLOG(3) << "Static RNN input sequence length = " << seq_len;
+    StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
+    auto reverse = Attr<bool>(kReverse);
+
     framework::Executor executor(place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
@@ -316,11 +332,12 @@ class RecurrentOp : public RecurrentBase {
   }
 
  private:
-  StepScopes CreateStepScopes(const framework::Scope &scope,
+  StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
+                              const framework::Scope &scope,
                               size_t seq_len) const {
     auto *var = scope.FindVar(Output(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
-    return StepScopes(scope, var->GetMutable<StepScopeVar>(),
+    return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len);
   }
 };
@@ -338,17 +355,18 @@ class RecurrentGradOp : public RecurrentBase {
                const platform::Place &place) const override {
     bool has_state = Attr<bool>(kHasStates);
     const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));
-    StepScopes scopes = CreateStepScopes(scope, seq_len);
+
+    // get device context from pool
+    platform::DeviceContextPool &pool =
+        platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(place);
+
+    StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
     auto reverse = Attr<bool>(kReverse);
 
     framework::Executor executor(place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
-    // get device context from pool
-    platform::DeviceContextPool &pool =
-        platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-
     for (size_t step_id = 0; step_id < seq_len; ++step_id) {
       size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
       VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
@@ -501,22 +519,20 @@ class RecurrentGradOp : public RecurrentBase {
       scopes.Next();
     }
     // Delete the scope of StepScopes
-    dev_ctx.Wait();
     auto *var = scope.FindVar(Input(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
-    auto step_scopes = var->GetMutable<StepScopeVar>();
-    for (auto *sub_scope : *step_scopes) {
-      const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
-    }
-    step_scopes->clear();
+    auto *step_scopes = var->GetMutable<StepScopeVar>();
+    ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope),
+                    step_scopes);
   }
 
  private:
-  StepScopes CreateStepScopes(const framework::Scope &scope,
+  StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
+                              const framework::Scope &scope,
                               size_t seq_len) const {
     auto *var = scope.FindVar(Input(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
-    return StepScopes(scope, var->GetMutable<StepScopeVar>(),
+    return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
   }
 };
-- 
GitLab
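
Note for reviewers: before this patch, only RecurrentGradOp freed the per-step scopes, so a forward-only run (no backward pass, as the subject says) left them behind and the next forward run tripped PADDLE_ENFORCE(scopes->empty()). The fix routes both operators through ClearStepScopes(), which waits on the device context before deleting leftover scopes through their parent. Below is a minimal, self-contained C++ sketch of that pattern; ToyScope and FakeDeviceWait are invented stand-ins for illustration only, not Paddle's real framework::Scope or platform::DeviceContext APIs.

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// ToyScope stands in for framework::Scope: a scope owns the children it
// creates and can delete a specific child on request.
class ToyScope {
 public:
  ToyScope *NewScope() {
    children_.emplace_back(new ToyScope);
    return children_.back().get();
  }

  void DeleteScope(ToyScope *child) {
    for (auto it = children_.begin(); it != children_.end(); ++it) {
      if (it->get() == child) {
        children_.erase(it);
        return;
      }
    }
  }

  std::size_t NumChildren() const { return children_.size(); }

 private:
  std::vector<std::unique_ptr<ToyScope>> children_;
};

// Stands in for DeviceContext::Wait(): block until queued device work
// that might still read the step scopes has drained.
static void FakeDeviceWait() {}

// The pattern the patch introduces: a cheap no-op on an empty list,
// otherwise wait for the device, delete every step scope through its
// parent, and clear the list so the next run starts clean.
static void ClearStepScopes(ToyScope *parent,
                            std::vector<ToyScope *> *step_scopes) {
  if (step_scopes->empty()) return;

  FakeDeviceWait();

  for (auto *sub_scope : *step_scopes) {
    parent->DeleteScope(sub_scope);
  }

  step_scopes->clear();
}

int main() {
  ToyScope parent;
  std::vector<ToyScope *> step_scopes;

  // A forward-only run leaves its per-step scopes behind, since only
  // the backward pass used to free them.
  for (int i = 0; i < 4; ++i) {
    step_scopes.push_back(parent.NewScope());
  }

  // The next forward run now clears the leftovers up front instead of
  // failing the old PADDLE_ENFORCE(scopes->empty()) check.
  ClearStepScopes(&parent, &step_scopes);
  std::cout << parent.NumChildren() << std::endl;  // prints 0
}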