From 3fd3b663a84eaf7cd59092a27fd3c7758103721b Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 18 Sep 2019 10:36:04 +0800 Subject: [PATCH] fix gc bug in controlflow ops, test=develop (#19827) --- paddle/fluid/framework/executor.cc | 7 +++-- ...onditional_block_op_eager_deletion_pass.cc | 2 +- .../recurrent_op_eager_deletion_pass.cc | 3 +- .../while_op_eager_deletion_pass.cc | 2 +- .../conditional_block_op_helper.cc | 28 +++++++++---------- .../controlflow/conditional_block_op_helper.h | 3 +- .../controlflow/recurrent_op_helper.cc | 21 ++++++-------- .../controlflow/recurrent_op_helper.h | 4 +-- .../operators/controlflow/while_op_helper.cc | 26 ++++++++--------- .../operators/controlflow/while_op_helper.h | 3 +- 10 files changed, 48 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 7c2467310f..df9b53d6a4 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -78,10 +78,11 @@ void ExecutorPrepareContext::PrepareUnusedVars( // If gc is enabled and block size > 1 if (prog_.Size() > 1) { operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( - block_id_, ops_); - operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(block_id_, ops_); + prog_, block_id_, ops_); + operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_, + ops_); operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( - block_id_, ops_); + prog_, block_id_, ops_); } unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars); } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc index 5bceb4e834..56a658d422 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc @@ -48,7 +48,7 @@ class ConditionalOpEagerDeletionPass : public Pass { auto &ifelse_ops = ops_pair.second.first; auto &ifelse_grad_ops = ops_pair.second.second; operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( - ifelse_ops, ifelse_grad_ops); + graph->OriginProgram(), ifelse_ops, ifelse_grad_ops); } } }; diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc index 40e07ce8b6..6077069ea7 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc @@ -40,7 +40,8 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const { // Prepare safe eager deletion on different devices because the garbage // collection may be different across devices OpAndGradOpPair &op_pair = entry.second; - PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair); + PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( + graph->OriginProgram(), &op_pair); } } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc index 63f996ade5..da0da4c712 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc @@ -47,7 +47,7 @@ class WhileOpEagerDeletionPass : public ir::Pass { auto &while_ops = ops_pair.second.first; auto &while_grad_ops = ops_pair.second.second; operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( - while_ops, while_grad_ops); + graph->OriginProgram(), while_ops, while_grad_ops); } } }; diff --git a/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc b/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc index 357a9d93b6..13a00c852a 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc @@ -29,16 +29,12 @@ static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp( } static void FindAllConditionalBlockAndConditionalBlockGradOp( - std::vector *fwd_ops, std::vector *bwd_ops) { + const framework::ProgramDesc &program, std::vector *fwd_ops, + std::vector *bwd_ops) { PADDLE_ENFORCE_GE(fwd_ops->size(), bwd_ops->size()); - if (fwd_ops->empty()) return; - - const auto *program = - fwd_ops->front().Attr("sub_block")->Program(); - - for (size_t i = 1; i < program->Size(); ++i) { - auto &block = program->Block(i); + for (size_t i = 1; i < program.Size(); ++i) { + auto &block = program.Block(i); for (size_t j = 0; j < block.OpSize(); ++j) { auto *op = block.Op(j); if (op->Type() == "conditional_block") { @@ -86,9 +82,10 @@ static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op, } static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl( - std::vector *ifelse_ops, + const framework::ProgramDesc &program, std::vector *ifelse_ops, std::vector *ifelse_grad_ops) { - FindAllConditionalBlockAndConditionalBlockGradOp(ifelse_ops, ifelse_grad_ops); + FindAllConditionalBlockAndConditionalBlockGradOp(program, ifelse_ops, + ifelse_grad_ops); VLOG(2) << "Found conditional_block op num: " << ifelse_ops->size() << ", conditional_block_grad op num: " << ifelse_grad_ops->size(); @@ -121,7 +118,7 @@ static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl( } void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( - int block_id, + const framework::ProgramDesc &program, int block_id, const std::vector> &all_ops) { // If block_id is not 0, returns // This is because all conditional_block_ops and conditional_block_grad_ops @@ -143,11 +140,12 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( } } - PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops, - &bwd_ops); + PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl( + program, &fwd_ops, &bwd_ops); } void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( + const framework::ProgramDesc &program, const std::vector &ifelse_ops, const std::vector &ifelse_grad_ops) { std::vector fwd_ops, bwd_ops; @@ -161,8 +159,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( bwd_ops.emplace_back(op); } - PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops, - &bwd_ops); + PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl( + program, &fwd_ops, &bwd_ops); } } // namespace operators diff --git a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h index 572b6ac4e4..f7dfba6f36 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h +++ b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h @@ -23,10 +23,11 @@ namespace paddle { namespace operators { void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( - int block_id, + const framework::ProgramDesc &program, int block_id, const std::vector> &all_ops); void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( + const framework::ProgramDesc &program, const std::vector &ifelse_ops, const std::vector &ifelse_grad_ops); diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.cc b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc index 6925086679..d2bb68272d 100644 --- a/paddle/fluid/operators/controlflow/recurrent_op_helper.cc +++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc @@ -65,7 +65,8 @@ static void AddSkipVars(const OpVariant &op, const Container &skip_vars) { // Find all ops and grad ops with given type name. The ops and grad ops // may locate in different blocks so we should traverse all blocks in the // program and find them out -static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op, +static void FindAllOpAndGradOp(const framework::ProgramDesc &program, + OpAndGradOpPair *op_and_grad_op, const std::string &type_name, const std::string &backward_type_name) { OpVariantSet &ops = op_and_grad_op->first; @@ -74,14 +75,8 @@ static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op, PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(), "There are extra grad ops in the graph or program"); - if (ops.empty()) return; - - const auto *program = - ops.begin() - ->Attr(RecurrentBase::kStepBlock) - ->Program(); - for (size_t i = 1; i < program->Size(); ++i) { - auto &block = program->Block(i); + for (size_t i = 1; i < program.Size(); ++i) { + auto &block = program.Block(i); for (size_t j = 0; j < block.OpSize(); ++j) { auto *op = block.Op(j); if (op->Type() == type_name) { @@ -201,7 +196,7 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr( } void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( - int block_id, + const framework::ProgramDesc &program, int block_id, const std::vector> &all_ops) { // If block_id is not 0, returns @@ -224,13 +219,13 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( op_pair.second.emplace(op.get()); } } - PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair); + PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(program, &op_pair); } void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( - OpAndGradOpPair *op_pair) { + const framework::ProgramDesc &program, OpAndGradOpPair *op_pair) { // Find all ops and grad ops at all blocks - FindAllOpAndGradOp(op_pair, "recurrent", "recurrent_grad"); + FindAllOpAndGradOp(program, op_pair, "recurrent", "recurrent_grad"); OpVariantSet &recurrent_ops = op_pair->first; OpVariantSet &recurrent_grad_ops = op_pair->second; diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.h b/paddle/fluid/operators/controlflow/recurrent_op_helper.h index b1e6e662c0..aacca0762c 100644 --- a/paddle/fluid/operators/controlflow/recurrent_op_helper.h +++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.h @@ -37,14 +37,14 @@ using OpAndGradOpPair = std::pair; // recurrent_grad ops at block 0 and the function will find all recurrent and // recurrent_grad ops across blocks. void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( - OpAndGradOpPair *op_pair); + const framework::ProgramDesc &program, OpAndGradOpPair *op_pair); // Set vars to skip eager deletion on input recurrent and recurrent_grad for // preparing safe eager deletion. The input block_id must be 0 and caller can // input all ops in the block. The function will find all recurrent and // recurrent_grad ops across blocks. void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( - int block_id, + const framework::ProgramDesc &program, int block_id, const std::vector> &all_ops); diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc index 009bc5796c..8f1e3f6092 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.cc +++ b/paddle/fluid/operators/controlflow/while_op_helper.cc @@ -100,16 +100,12 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op, // Find all while_ops and while_grad_ops in the graph or program // The while_grad_op and while_op may located in different blocks // So we should traverse all blocks in the program and find them out. -static void FindAllWhileAndWhileGradOp(std::vector *while_ops, +static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program, + std::vector *while_ops, std::vector *while_grad_ops) { PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size()); - - if (while_ops->empty()) return; - - const auto *program = - while_ops->front().Attr(kStepBlock)->Program(); - for (size_t i = 1; i < program->Size(); ++i) { - auto &block = program->Block(i); + for (size_t i = 1; i < program.Size(); ++i) { + auto &block = program.Block(i); for (size_t j = 0; j < block.OpSize(); ++j) { auto *op = block.Op(j); if (op->Type() == "while") { @@ -125,8 +121,9 @@ static void FindAllWhileAndWhileGradOp(std::vector *while_ops, } static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl( - std::vector *while_ops, std::vector *while_grad_ops) { - FindAllWhileAndWhileGradOp(while_ops, while_grad_ops); + const framework::ProgramDesc &program, std::vector *while_ops, + std::vector *while_grad_ops) { + FindAllWhileAndWhileGradOp(program, while_ops, while_grad_ops); VLOG(2) << "Found while op num: " << while_ops->size() << ", while grad op num: " << while_grad_ops->size(); @@ -155,7 +152,7 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl( } void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( - int block_id, + const framework::ProgramDesc &program, int block_id, const std::vector> &all_ops) { // If block_id is not 0, returns // This is because all while_ops and while_grad_ops in the whole program @@ -176,10 +173,12 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( bwd_ops.emplace_back(op.get()); } } - PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops); + PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops, + &bwd_ops); } void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( + const framework::ProgramDesc &program, const std::vector &while_ops, const std::vector &while_grad_ops) { std::vector fwd_ops, bwd_ops; @@ -193,7 +192,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( bwd_ops.emplace_back(op); } - PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops); + PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops, + &bwd_ops); } } // namespace operators diff --git a/paddle/fluid/operators/controlflow/while_op_helper.h b/paddle/fluid/operators/controlflow/while_op_helper.h index 456ba8642b..e2cfece658 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.h +++ b/paddle/fluid/operators/controlflow/while_op_helper.h @@ -32,10 +32,11 @@ static constexpr char kOutputs[] = "Out"; static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( - int block_id, + const framework::ProgramDesc &program, int block_id, const std::vector> &all_ops); void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( + const framework::ProgramDesc &program, const std::vector &while_ops, const std::vector &while_grad_ops); -- GitLab