diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 7c2467310fc3ecaa863dc15b96f2d4b88f281dfb..df9b53d6a4045489e6f402fdca91ec0d758af0ea 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -78,10 +78,11 @@ void ExecutorPrepareContext::PrepareUnusedVars(
   // If gc is enabled and block size > 1
   if (prog_.Size() > 1) {
     operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-        block_id_, ops_);
-    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(block_id_, ops_);
+        prog_, block_id_, ops_);
+    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(prog_, block_id_,
+                                                               ops_);
     operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-        block_id_, ops_);
+        prog_, block_id_, ops_);
   }
   unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
 }
diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc
index 5bceb4e8346ae04945da72ce248a187adb1288b3..56a658d4220add287f95f7b596c6a013ee64d229 100644
--- a/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc
@@ -48,7 +48,7 @@ class ConditionalOpEagerDeletionPass : public Pass {
       auto &ifelse_ops = ops_pair.second.first;
       auto &ifelse_grad_ops = ops_pair.second.second;
       operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-          ifelse_ops, ifelse_grad_ops);
+          graph->OriginProgram(), ifelse_ops, ifelse_grad_ops);
     }
   }
 };
diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
index 40e07ce8b6d1e52de3ab87cad96691cae3dd37b6..6077069ea747a60b5989c5da373536e6654b2b74 100644
--- a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
@@ -40,7 +40,8 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
     // Prepare safe eager deletion on different devices because the garbage
     // collection may be different across devices
     OpAndGradOpPair &op_pair = entry.second;
-    PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
+    PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
+        graph->OriginProgram(), &op_pair);
   }
 }
 
diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
index 63f996ade5648c80ab3e505ca9cddd80f93a7ef4..da0da4c7125953d386fbd4d14bc2607837616cc3 100644
--- a/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
@@ -47,7 +47,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
       auto &while_ops = ops_pair.second.first;
       auto &while_grad_ops = ops_pair.second.second;
       operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
-          while_ops, while_grad_ops);
+          graph->OriginProgram(), while_ops, while_grad_ops);
     }
   }
 };
diff --git a/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc b/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc
index 357a9d93b69a4758359e9a68cdec7c286482cc1b..13a00c852a27da2b75056ffbcdc0873ee553e2a8 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op_helper.cc
@@ -29,16 +29,12 @@ static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp(
 }
 
 static void FindAllConditionalBlockAndConditionalBlockGradOp(
-    std::vector<OpVariant> *fwd_ops, std::vector<OpVariant> *bwd_ops) {
+    const framework::ProgramDesc &program, std::vector<OpVariant> *fwd_ops,
+    std::vector<OpVariant> *bwd_ops) {
   PADDLE_ENFORCE_GE(fwd_ops->size(), bwd_ops->size());
 
-  if (fwd_ops->empty()) return;
-
-  const auto *program =
-      fwd_ops->front().Attr<framework::BlockDesc *>("sub_block")->Program();
-
-  for (size_t i = 1; i < program->Size(); ++i) {
-    auto &block = program->Block(i);
+  for (size_t i = 1; i < program.Size(); ++i) {
+    auto &block = program.Block(i);
     for (size_t j = 0; j < block.OpSize(); ++j) {
       auto *op = block.Op(j);
       if (op->Type() == "conditional_block") {
@@ -86,9 +82,10 @@ static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op,
 }
 
 static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
-    std::vector<OpVariant> *ifelse_ops,
+    const framework::ProgramDesc &program, std::vector<OpVariant> *ifelse_ops,
     std::vector<OpVariant> *ifelse_grad_ops) {
-  FindAllConditionalBlockAndConditionalBlockGradOp(ifelse_ops, ifelse_grad_ops);
+  FindAllConditionalBlockAndConditionalBlockGradOp(program, ifelse_ops,
+                                                   ifelse_grad_ops);
 
   VLOG(2) << "Found conditional_block op num: " << ifelse_ops->size()
           << ", conditional_block_grad op num: " << ifelse_grad_ops->size();
@@ -121,7 +118,7 @@ static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
 }
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
   // If block_id is not 0, returns
   // This is because all conditional_block_ops and conditional_block_grad_ops
@@ -143,11 +140,12 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
     }
   }
 
-  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
-                                                                  &bwd_ops);
+  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
+      program, &fwd_ops, &bwd_ops);
 }
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<OpVariant> &ifelse_ops,
     const std::vector<OpVariant> &ifelse_grad_ops) {
   std::vector<OpVariant> fwd_ops, bwd_ops;
@@ -161,8 +159,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
     bwd_ops.emplace_back(op);
   }
 
-  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(&fwd_ops,
-                                                                  &bwd_ops);
+  PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
+      program, &fwd_ops, &bwd_ops);
 }
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h
index 572b6ac4e466fd070f3955b0c2379bd1c67d0825..f7dfba6f364e197a97cc5e061e42cd5cc84309db 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h
+++ b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h
@@ -23,10 +23,11 @@ namespace paddle {
 namespace operators {
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
 
 void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<OpVariant> &ifelse_ops,
     const std::vector<OpVariant> &ifelse_grad_ops);
 
diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.cc b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc
index 6925086679b2f2926755e6b0b21ef43d3f62316c..d2bb68272dff46e36349baf23fff88433950b3fd 100644
--- a/paddle/fluid/operators/controlflow/recurrent_op_helper.cc
+++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc
@@ -65,7 +65,8 @@ static void AddSkipVars(const OpVariant &op, const Container &skip_vars) {
 // Find all ops and grad ops with given type name. The ops and grad ops
 // may locate in different blocks so we should traverse all blocks in the
 // program and find them out
-static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
+static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
+                               OpAndGradOpPair *op_and_grad_op,
                                const std::string &type_name,
                                const std::string &backward_type_name) {
   OpVariantSet &ops = op_and_grad_op->first;
@@ -74,14 +75,8 @@ static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
   PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
                     "There are extra grad ops in the graph or program");
 
-  if (ops.empty()) return;
-
-  const auto *program =
-      ops.begin()
-          ->Attr<framework::BlockDesc *>(RecurrentBase::kStepBlock)
-          ->Program();
-  for (size_t i = 1; i < program->Size(); ++i) {
-    auto &block = program->Block(i);
+  for (size_t i = 1; i < program.Size(); ++i) {
+    auto &block = program.Block(i);
     for (size_t j = 0; j < block.OpSize(); ++j) {
       auto *op = block.Op(j);
       if (op->Type() == type_name) {
@@ -201,7 +196,7 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
 }
 
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
         &all_ops) {
   // If block_id is not 0, returns
@@ -224,13 +219,13 @@
       op_pair.second.emplace(op.get());
     }
   }
-  PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
+  PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(program, &op_pair);
 }
 
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    OpAndGradOpPair *op_pair) {
+    const framework::ProgramDesc &program, OpAndGradOpPair *op_pair) {
   // Find all ops and grad ops at all blocks
-  FindAllOpAndGradOp(op_pair, "recurrent", "recurrent_grad");
+  FindAllOpAndGradOp(program, op_pair, "recurrent", "recurrent_grad");
 
   OpVariantSet &recurrent_ops = op_pair->first;
   OpVariantSet &recurrent_grad_ops = op_pair->second;
diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.h b/paddle/fluid/operators/controlflow/recurrent_op_helper.h
index b1e6e662c082c6f78229344a11446bd7863a7d84..aacca0762ca1d45634d36da572448dae7e9fe195 100644
--- a/paddle/fluid/operators/controlflow/recurrent_op_helper.h
+++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.h
@@ -37,14 +37,14 @@ using OpAndGradOpPair = std::pair<OpVariantSet, OpVariantSet>;
 
 // recurrent_grad ops at block 0 and the function will find all recurrent and
 // recurrent_grad ops across blocks.
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    OpAndGradOpPair *op_pair);
+    const framework::ProgramDesc &program, OpAndGradOpPair *op_pair);
 
 // Set vars to skip eager deletion on input recurrent and recurrent_grad for
 // preparing safe eager deletion. The input block_id must be 0 and caller can
 // input all ops in the block. The function will find all recurrent and
 // recurrent_grad ops across blocks.
 void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
         &all_ops);
diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc
index 009bc5796ce2e94de94fd400c2012752a002324e..8f1e3f60927abc23c18c208efbd77715e40136bc 100644
--- a/paddle/fluid/operators/controlflow/while_op_helper.cc
+++ b/paddle/fluid/operators/controlflow/while_op_helper.cc
@@ -100,16 +100,12 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
 // Find all while_ops and while_grad_ops in the graph or program
 // The while_grad_op and while_op may located in different blocks
 // So we should traverse all blocks in the program and find them out.
-static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
+static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program,
+                                       std::vector<OpVariant> *while_ops,
                                        std::vector<OpVariant> *while_grad_ops) {
   PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
-
-  if (while_ops->empty()) return;
-
-  const auto *program =
-      while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
-  for (size_t i = 1; i < program->Size(); ++i) {
-    auto &block = program->Block(i);
+  for (size_t i = 1; i < program.Size(); ++i) {
+    auto &block = program.Block(i);
     for (size_t j = 0; j < block.OpSize(); ++j) {
       auto *op = block.Op(j);
       if (op->Type() == "while") {
@@ -125,8 +121,9 @@ static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
 }
 
 static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
-    std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
-  FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
+    const framework::ProgramDesc &program, std::vector<OpVariant> *while_ops,
+    std::vector<OpVariant> *while_grad_ops) {
+  FindAllWhileAndWhileGradOp(program, while_ops, while_grad_ops);
 
   VLOG(2) << "Found while op num: " << while_ops->size()
           << ", while grad op num: " << while_grad_ops->size();
@@ -155,7 +152,7 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
 }
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
   // If block_id is not 0, returns
   // This is because all while_ops and while_grad_ops in the whole program
@@ -176,10 +173,12 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
       bwd_ops.emplace_back(op.get());
     }
   }
-  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
+                                                      &bwd_ops);
 }
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<OpVariant> &while_ops,
     const std::vector<OpVariant> &while_grad_ops) {
   std::vector<OpVariant> fwd_ops, bwd_ops;
@@ -193,7 +192,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
     bwd_ops.emplace_back(op);
   }
 
-  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
+                                                      &bwd_ops);
 }
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/controlflow/while_op_helper.h b/paddle/fluid/operators/controlflow/while_op_helper.h
index 456ba8642b9bd32a1236d112cc8b387ae6a279d3..e2cfece658088b8e8b74ae52da4b43b21c01127c 100644
--- a/paddle/fluid/operators/controlflow/while_op_helper.h
+++ b/paddle/fluid/operators/controlflow/while_op_helper.h
@@ -32,10 +32,11 @@ static constexpr char kOutputs[] = "Out";
 static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
-    int block_id,
+    const framework::ProgramDesc &program, int block_id,
     const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
 
 void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const framework::ProgramDesc &program,
     const std::vector<OpVariant> &while_ops,
     const std::vector<OpVariant> &while_grad_ops);
 
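Note: before this change, each helper recovered the ProgramDesc from the attribute of the first collected forward op (the "sub_block" / step-block attribute) and returned early when no forward op had been collected, so a program whose block 0 contains no such ops was never traversed. Passing the program explicitly removes both the early return and the hidden dependency on the op list being non-empty.

A minimal caller-side sketch of the new signatures, mirroring the executor.cc hunk above. The wrapper name PrepareControlFlowEagerDeletion is hypothetical; "program" and "ops" stand for an existing ProgramDesc and the operators instantiated from its block 0:

    #include <memory>
    #include <vector>

    #include "paddle/fluid/framework/program_desc.h"
    #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
    #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
    #include "paddle/fluid/operators/controlflow/while_op_helper.h"

    // Prepares skip-vars for safe eager deletion of every control-flow op.
    // block_id must be 0 (the helpers return early otherwise); each helper
    // scans all sub-blocks of `program` itself.
    void PrepareControlFlowEagerDeletion(
        const paddle::framework::ProgramDesc &program,
        const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
            &ops) {
      namespace pd_ops = paddle::operators;
      pd_ops::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
          program, /*block_id=*/0, ops);
      pd_ops::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
          program, /*block_id=*/0, ops);
      pd_ops::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
          program, /*block_id=*/0, ops);
    }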