From 2704479bb22fda11225b9d9ccaf757648c70fbb6 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Thu, 30 May 2019 10:24:33 +0800 Subject: [PATCH] Optimize recurrent_op using Prepare and RunPreparedContext, avoiding create operators in every iter. (#17689) test=develop --- paddle/fluid/operators/recurrent_op.cc | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc index ac432f4dd03..b3bb1abf4da 100644 --- a/paddle/fluid/operators/recurrent_op.cc +++ b/paddle/fluid/operators/recurrent_op.cc @@ -272,6 +272,9 @@ class RecurrentOp : public RecurrentBase { auto *block = Attr(kStepBlock); auto *program = block->Program(); + auto ctx = executor.Prepare( + *program, block->ID(), std::vector() /*skip_ref_cnt_vars*/, + true /*force_disable_gc*/); for (size_t i = 0; i < seq_len; ++i) { size_t seq_offset = reverse ? seq_len - i - 1 : i; @@ -305,10 +308,9 @@ class RecurrentOp : public RecurrentBase { } // Every inputs are linked now, execute! - executor.Run(*program, &cur_scope, block->ID(), - false /*create_local_scope*/, true /*create_vars*/, - std::vector() /*skip_ref_cnt_vars*/, - true /*force_disable_gc*/); + executor.RunPreparedContext(ctx.get(), &cur_scope, + false /*create_local_scope*/, + true /*create_vars*/, true /* keep_kids */); // Copy inside::output -> outside::output // outside::output[seq_offset: seq_offset + 1] = inside::output @@ -366,6 +368,9 @@ class RecurrentGradOp : public RecurrentBase { framework::Executor executor(place); auto *block = Attr(kStepBlock); auto *program = block->Program(); + auto ctx = executor.Prepare( + *program, block->ID(), std::vector() /*skip_ref_cnt_vars*/, + true /*force_disable_gc*/); for (size_t step_id = 0; step_id < seq_len; ++step_id) { size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; @@ -423,10 +428,9 @@ class RecurrentGradOp : public RecurrentBase { VLOG(5) << "Recurrent memory linking finished "; // Run step block with cur_scope - executor.Run(*program, &cur_scope, block->ID(), - false /*create_local_scope*/, true /*create_vars*/, - std::vector() /*skip_ref_cnt_vars*/, - true /*force_disable_gc*/); + executor.RunPreparedContext(ctx.get(), &cur_scope, + false /*create_local_scope*/, + true /*create_vars*/, true /* keep_kids */); VLOG(5) << "executor.Run finished "; -- GitLab