未验证 提交 2704479b 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize recurrent_op using Prepare and RunPreparedContext, avoiding create...

Optimize recurrent_op using Prepare and RunPreparedContext to avoid creating operators in every iteration. (#17689)

test=develop
上级 9b998764
...@@ -272,6 +272,9 @@ class RecurrentOp : public RecurrentBase { ...@@ -272,6 +272,9 @@ class RecurrentOp : public RecurrentBase {
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto ctx = executor.Prepare(
*program, block->ID(), std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
size_t seq_offset = reverse ? seq_len - i - 1 : i; size_t seq_offset = reverse ? seq_len - i - 1 : i;
...@@ -305,10 +308,9 @@ class RecurrentOp : public RecurrentBase { ...@@ -305,10 +308,9 @@ class RecurrentOp : public RecurrentBase {
} }
// Every inputs are linked now, execute! // Every inputs are linked now, execute!
executor.Run(*program, &cur_scope, block->ID(), executor.RunPreparedContext(ctx.get(), &cur_scope,
false /*create_local_scope*/, true /*create_vars*/, false /*create_local_scope*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/, true /*create_vars*/, true /* keep_kids */);
true /*force_disable_gc*/);
// Copy inside::output -> outside::output // Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output // outside::output[seq_offset: seq_offset + 1] = inside::output
...@@ -366,6 +368,9 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -366,6 +368,9 @@ class RecurrentGradOp : public RecurrentBase {
framework::Executor executor(place); framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto ctx = executor.Prepare(
*program, block->ID(), std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
for (size_t step_id = 0; step_id < seq_len; ++step_id) { for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
...@@ -423,10 +428,9 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -423,10 +428,9 @@ class RecurrentGradOp : public RecurrentBase {
VLOG(5) << "Recurrent memory linking finished "; VLOG(5) << "Recurrent memory linking finished ";
// Run step block with cur_scope // Run step block with cur_scope
executor.Run(*program, &cur_scope, block->ID(), executor.RunPreparedContext(ctx.get(), &cur_scope,
false /*create_local_scope*/, true /*create_vars*/, false /*create_local_scope*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/, true /*create_vars*/, true /* keep_kids */);
true /*force_disable_gc*/);
VLOG(5) << "executor.Run finished "; VLOG(5) << "executor.Run finished ";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册